Add typing hints for MQTT mixins (#80702)

* Add typing hints for MQTT mixins

* Follow up comments

* config_entry is always set

* typing discovery_data - substate None assignment

* Rename `config[CONF_DEVICE]` -> specifications
This commit is contained in:
Jan Bouwhuis 2022-10-24 15:00:37 +02:00 committed by GitHub
parent 64d6d04ade
commit 2f11385627
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 132 additions and 106 deletions

View File

@ -552,7 +552,7 @@ class MqttCover(MqttEntity, CoverEntity):
This method is a coroutine. This method is a coroutine.
""" """
await self.async_publish( await self.async_publish(
self._config.get(CONF_COMMAND_TOPIC), self._config[CONF_COMMAND_TOPIC],
self._config[CONF_PAYLOAD_OPEN], self._config[CONF_PAYLOAD_OPEN],
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -573,7 +573,7 @@ class MqttCover(MqttEntity, CoverEntity):
This method is a coroutine. This method is a coroutine.
""" """
await self.async_publish( await self.async_publish(
self._config.get(CONF_COMMAND_TOPIC), self._config[CONF_COMMAND_TOPIC],
self._config[CONF_PAYLOAD_CLOSE], self._config[CONF_PAYLOAD_CLOSE],
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -594,7 +594,7 @@ class MqttCover(MqttEntity, CoverEntity):
This method is a coroutine. This method is a coroutine.
""" """
await self.async_publish( await self.async_publish(
self._config.get(CONF_COMMAND_TOPIC), self._config[CONF_COMMAND_TOPIC],
self._config[CONF_PAYLOAD_STOP], self._config[CONF_PAYLOAD_STOP],
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -614,7 +614,7 @@ class MqttCover(MqttEntity, CoverEntity):
} }
tilt_payload = self._set_tilt_template(tilt_open_position, variables=variables) tilt_payload = self._set_tilt_template(tilt_open_position, variables=variables)
await self.async_publish( await self.async_publish(
self._config.get(CONF_TILT_COMMAND_TOPIC), self._config[CONF_TILT_COMMAND_TOPIC],
tilt_payload, tilt_payload,
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -641,7 +641,7 @@ class MqttCover(MqttEntity, CoverEntity):
tilt_closed_position, variables=variables tilt_closed_position, variables=variables
) )
await self.async_publish( await self.async_publish(
self._config.get(CONF_TILT_COMMAND_TOPIC), self._config[CONF_TILT_COMMAND_TOPIC],
tilt_payload, tilt_payload,
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -670,7 +670,7 @@ class MqttCover(MqttEntity, CoverEntity):
tilt = self._set_tilt_template(tilt, variables=variables) tilt = self._set_tilt_template(tilt, variables=variables)
await self.async_publish( await self.async_publish(
self._config.get(CONF_TILT_COMMAND_TOPIC), self._config[CONF_TILT_COMMAND_TOPIC],
tilt, tilt,
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],
@ -697,7 +697,7 @@ class MqttCover(MqttEntity, CoverEntity):
position = self._set_position_template(position, variables=variables) position = self._set_position_template(position, variables=variables)
await self.async_publish( await self.async_publish(
self._config.get(CONF_SET_POSITION_TOPIC), self._config[CONF_SET_POSITION_TOPIC],
position, position,
self._config[CONF_QOS], self._config[CONF_QOS],
self._config[CONF_RETAIN], self._config[CONF_RETAIN],

View File

@ -7,6 +7,7 @@ import functools
import logging import logging
import re import re
import time import time
from typing import Any
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_DEVICE, CONF_PLATFORM from homeassistant.const import CONF_DEVICE, CONF_PLATFORM
@ -19,7 +20,7 @@ from homeassistant.helpers.dispatcher import (
) )
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads
from homeassistant.helpers.service_info.mqtt import MqttServiceInfo from homeassistant.helpers.service_info.mqtt import MqttServiceInfo
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import DiscoveryInfoType
from homeassistant.loader import async_get_mqtt from homeassistant.loader import async_get_mqtt
from .. import mqtt from .. import mqtt
@ -73,8 +74,8 @@ MQTT_DISCOVERY_DONE = "mqtt_discovery_done_{}"
TOPIC_BASE = "~" TOPIC_BASE = "~"
class MQTTConfig(dict): class MQTTDiscoveryPayload(dict[str, Any]):
"""Dummy class to allow adding attributes.""" """Class to hold and MQTT discovery payload and discovery data."""
discovery_data: DiscoveryInfoType discovery_data: DiscoveryInfoType
@ -96,7 +97,7 @@ async def async_start( # noqa: C901
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
mqtt_integrations = {} mqtt_integrations = {}
async def async_discovery_message_received(msg): async def async_discovery_message_received(msg) -> None:
"""Process the received message.""" """Process the received message."""
mqtt_data.last_discovery = time.time() mqtt_data.last_discovery = time.time()
payload = msg.payload payload = msg.payload
@ -126,7 +127,7 @@ async def async_start( # noqa: C901
_LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload) _LOGGER.warning("Unable to parse JSON %s: '%s'", object_id, payload)
return return
payload = MQTTConfig(payload) payload = MQTTDiscoveryPayload(payload)
for key in list(payload): for key in list(payload):
abbreviated_key = key abbreviated_key = key
@ -195,7 +196,7 @@ async def async_start( # noqa: C901
await async_process_discovery_payload(component, discovery_id, payload) await async_process_discovery_payload(component, discovery_id, payload)
async def async_process_discovery_payload( async def async_process_discovery_payload(
component: str, discovery_id: str, payload: ConfigType component: str, discovery_id: str, payload: MQTTDiscoveryPayload
) -> None: ) -> None:
"""Process the payload of a new discovery.""" """Process the payload of a new discovery."""

View File

@ -34,7 +34,10 @@ from homeassistant.helpers import (
device_registry as dr, device_registry as dr,
entity_registry as er, entity_registry as er,
) )
from homeassistant.helpers.device_registry import EVENT_DEVICE_REGISTRY_UPDATED from homeassistant.helpers.device_registry import (
EVENT_DEVICE_REGISTRY_UPDATED,
DeviceEntry,
)
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
async_dispatcher_send, async_dispatcher_send,
@ -50,6 +53,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_entity_registry_updated_event from homeassistant.helpers.event import async_track_entity_registry_updated_event
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from homeassistant.helpers.json import json_loads from homeassistant.helpers.json import json_loads
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import debug_info, subscription from . import debug_info, subscription
@ -74,11 +78,13 @@ from .discovery import (
MQTT_DISCOVERY_DONE, MQTT_DISCOVERY_DONE,
MQTT_DISCOVERY_NEW, MQTT_DISCOVERY_NEW,
MQTT_DISCOVERY_UPDATED, MQTT_DISCOVERY_UPDATED,
MQTTDiscoveryPayload,
clear_discovery_hash, clear_discovery_hash,
set_discovery_hash, set_discovery_hash,
) )
from .models import MqttValueTemplate, PublishPayloadType, ReceiveMessage from .models import MqttValueTemplate, PublishPayloadType, ReceiveMessage
from .subscription import ( from .subscription import (
EntitySubscription,
async_prepare_subscribe_topics, async_prepare_subscribe_topics,
async_subscribe_topics, async_subscribe_topics,
async_unsubscribe_topics, async_unsubscribe_topics,
@ -222,7 +228,7 @@ MQTT_ENTITY_COMMON_SCHEMA = MQTT_AVAILABILITY_SCHEMA.extend(
) )
def warn_for_legacy_schema(domain: str) -> Callable: def warn_for_legacy_schema(domain: str) -> Callable[[ConfigType], ConfigType]:
"""Warn once when a legacy platform schema is used.""" """Warn once when a legacy platform schema is used."""
warned = set() warned = set()
@ -269,8 +275,8 @@ class SetupEntity(Protocol):
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict[str, Any] | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Define setup_entities type.""" """Define setup_entities type."""
@ -294,13 +300,13 @@ async def async_get_platform_config_from_yaml(
async def async_setup_entry_helper( async def async_setup_entry_helper(
hass: HomeAssistant, hass: HomeAssistant,
domain: str, domain: str,
async_setup: partial[Coroutine[HomeAssistant, str, None]], async_setup: partial[Coroutine[Any, Any, None]],
discovery_schema: vol.Schema, discovery_schema: vol.Schema,
) -> None: ) -> None:
"""Set up entity, automation or tag creation dynamically through MQTT discovery.""" """Set up entity, automation or tag creation dynamically through MQTT discovery."""
mqtt_data = get_mqtt_data(hass) mqtt_data = get_mqtt_data(hass)
async def async_discover(discovery_payload): async def async_discover(discovery_payload: MQTTDiscoveryPayload) -> None:
"""Discover and add an MQTT entity, automation or tag.""" """Discover and add an MQTT entity, automation or tag."""
if not mqtt_config_entry_enabled(hass): if not mqtt_config_entry_enabled(hass):
_LOGGER.warning( _LOGGER.warning(
@ -312,10 +318,10 @@ async def async_setup_entry_helper(
return return
discovery_data = discovery_payload.discovery_data discovery_data = discovery_payload.discovery_data
try: try:
config = discovery_schema(discovery_payload) config: DiscoveryInfoType = discovery_schema(discovery_payload)
await async_setup(config, discovery_data=discovery_data) await async_setup(config, discovery_data=discovery_data)
except Exception: except Exception:
discovery_hash = discovery_data[ATTR_DISCOVERY_HASH] discovery_hash: tuple[str, str] = discovery_data[ATTR_DISCOVERY_HASH]
clear_discovery_hash(hass, discovery_hash) clear_discovery_hash(hass, discovery_hash)
async_dispatcher_send( async_dispatcher_send(
hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None
@ -357,7 +363,7 @@ async def async_setup_entry_helper(
async def async_setup_platform_helper( async def async_setup_platform_helper(
hass: HomeAssistant, hass: HomeAssistant,
platform_domain: str, platform_domain: str,
config: ConfigType | DiscoveryInfoType, config: ConfigType,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
async_setup_entities: SetupEntity, async_setup_entities: SetupEntity,
) -> None: ) -> None:
@ -381,7 +387,9 @@ async def async_setup_platform_helper(
await async_setup_entities(hass, async_add_entities, config, config_entry) await async_setup_entities(hass, async_add_entities, config, config_entry)
def init_entity_id_from_config(hass, entity, config, entity_id_format): def init_entity_id_from_config(
hass: HomeAssistant, entity: Entity, config: ConfigType, entity_id_format: str
) -> None:
"""Set entity_id from object_id if defined in config.""" """Set entity_id from object_id if defined in config."""
if CONF_OBJECT_ID in config: if CONF_OBJECT_ID in config:
entity.entity_id = async_generate_entity_id( entity.entity_id = async_generate_entity_id(
@ -394,10 +402,10 @@ class MqttAttributes(Entity):
_attributes_extra_blocked: frozenset[str] = frozenset() _attributes_extra_blocked: frozenset[str] = frozenset()
def __init__(self, config: dict) -> None: def __init__(self, config: ConfigType) -> None:
"""Initialize the JSON attributes mixin.""" """Initialize the JSON attributes mixin."""
self._attributes: dict[str, Any] | None = None self._attributes: dict[str, Any] | None = None
self._attributes_sub_state = None self._attributes_sub_state: dict[str, EntitySubscription] = {}
self._attributes_config = config self._attributes_config = config
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
@ -406,16 +414,16 @@ class MqttAttributes(Entity):
self._attributes_prepare_subscribe_topics() self._attributes_prepare_subscribe_topics()
await self._attributes_subscribe_topics() await self._attributes_subscribe_topics()
def attributes_prepare_discovery_update(self, config: dict): def attributes_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._attributes_config = config self._attributes_config = config
self._attributes_prepare_subscribe_topics() self._attributes_prepare_subscribe_topics()
async def attributes_discovery_update(self, config: dict): async def attributes_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._attributes_subscribe_topics() await self._attributes_subscribe_topics()
def _attributes_prepare_subscribe_topics(self): def _attributes_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
attr_tpl = MqttValueTemplate( attr_tpl = MqttValueTemplate(
self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self self._attributes_config.get(CONF_JSON_ATTRS_TEMPLATE), entity=self
@ -458,11 +466,11 @@ class MqttAttributes(Entity):
}, },
) )
async def _attributes_subscribe_topics(self): async def _attributes_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._attributes_sub_state) await async_subscribe_topics(self.hass, self._attributes_sub_state)
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._attributes_sub_state = async_unsubscribe_topics( self._attributes_sub_state = async_unsubscribe_topics(
self.hass, self._attributes_sub_state self.hass, self._attributes_sub_state
@ -477,11 +485,11 @@ class MqttAttributes(Entity):
class MqttAvailability(Entity): class MqttAvailability(Entity):
"""Mixin used for platforms that report availability.""" """Mixin used for platforms that report availability."""
def __init__(self, config: dict) -> None: def __init__(self, config: ConfigType) -> None:
"""Initialize the availability mixin.""" """Initialize the availability mixin."""
self._availability_sub_state = None self._availability_sub_state: dict[str, EntitySubscription] = {}
self._available: dict = {} self._available: dict[str, str | bool] = {}
self._available_latest = False self._available_latest: bool = False
self._availability_setup_from_config(config) self._availability_setup_from_config(config)
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
@ -498,18 +506,18 @@ class MqttAvailability(Entity):
) )
) )
def availability_prepare_discovery_update(self, config: dict): def availability_prepare_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._availability_setup_from_config(config) self._availability_setup_from_config(config)
self._availability_prepare_subscribe_topics() self._availability_prepare_subscribe_topics()
async def availability_discovery_update(self, config: dict): async def availability_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
await self._availability_subscribe_topics() await self._availability_subscribe_topics()
def _availability_setup_from_config(self, config): def _availability_setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup.""" """(Re)Setup."""
self._avail_topics = {} self._avail_topics: dict[str, dict[str, Any]] = {}
if CONF_AVAILABILITY_TOPIC in config: if CONF_AVAILABILITY_TOPIC in config:
self._avail_topics[config[CONF_AVAILABILITY_TOPIC]] = { self._avail_topics[config[CONF_AVAILABILITY_TOPIC]] = {
CONF_PAYLOAD_AVAILABLE: config[CONF_PAYLOAD_AVAILABLE], CONF_PAYLOAD_AVAILABLE: config[CONF_PAYLOAD_AVAILABLE],
@ -518,6 +526,7 @@ class MqttAvailability(Entity):
} }
if CONF_AVAILABILITY in config: if CONF_AVAILABILITY in config:
avail: dict[str, Any]
for avail in config[CONF_AVAILABILITY]: for avail in config[CONF_AVAILABILITY]:
self._avail_topics[avail[CONF_TOPIC]] = { self._avail_topics[avail[CONF_TOPIC]] = {
CONF_PAYLOAD_AVAILABLE: avail[CONF_PAYLOAD_AVAILABLE], CONF_PAYLOAD_AVAILABLE: avail[CONF_PAYLOAD_AVAILABLE],
@ -533,7 +542,7 @@ class MqttAvailability(Entity):
self._avail_config = config self._avail_config = config
def _availability_prepare_subscribe_topics(self): def _availability_prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@ -541,6 +550,7 @@ class MqttAvailability(Entity):
def availability_message_received(msg: ReceiveMessage) -> None: def availability_message_received(msg: ReceiveMessage) -> None:
"""Handle a new received MQTT availability message.""" """Handle a new received MQTT availability message."""
topic = msg.topic topic = msg.topic
payload: ReceivePayloadType
payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload) payload = self._avail_topics[topic][CONF_AVAILABILITY_TEMPLATE](msg.payload)
if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]: if payload == self._avail_topics[topic][CONF_PAYLOAD_AVAILABLE]:
self._available[topic] = True self._available[topic] = True
@ -555,7 +565,7 @@ class MqttAvailability(Entity):
topic: (self._available[topic] if topic in self._available else False) topic: (self._available[topic] if topic in self._available else False)
for topic in self._avail_topics for topic in self._avail_topics
} }
topics = { topics: dict[str, dict[str, Any]] = {
f"availability_{topic}": { f"availability_{topic}": {
"topic": topic, "topic": topic,
"msg_callback": availability_message_received, "msg_callback": availability_message_received,
@ -571,17 +581,17 @@ class MqttAvailability(Entity):
topics, topics,
) )
async def _availability_subscribe_topics(self): async def _availability_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await async_subscribe_topics(self.hass, self._availability_sub_state) await async_subscribe_topics(self.hass, self._availability_sub_state)
@callback @callback
def async_mqtt_connect(self): def async_mqtt_connect(self) -> None:
"""Update state on connection/disconnection to MQTT broker.""" """Update state on connection/disconnection to MQTT broker."""
if not self.hass.is_stopping: if not self.hass.is_stopping:
self.async_write_ha_state() self.async_write_ha_state()
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._availability_sub_state = async_unsubscribe_topics( self._availability_sub_state = async_unsubscribe_topics(
self.hass, self._availability_sub_state self.hass, self._availability_sub_state
@ -628,12 +638,12 @@ async def cleanup_device_registry(
) )
def get_discovery_hash(discovery_data: dict) -> tuple[str, str]: def get_discovery_hash(discovery_data: DiscoveryInfoType) -> tuple[str, str]:
"""Get the discovery hash from the discovery data.""" """Get the discovery hash from the discovery data."""
return discovery_data[ATTR_DISCOVERY_HASH] return discovery_data[ATTR_DISCOVERY_HASH]
def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None: def send_discovery_done(hass: HomeAssistant, discovery_data: DiscoveryInfoType) -> None:
"""Acknowledge a discovery message has been handled.""" """Acknowledge a discovery message has been handled."""
discovery_hash = get_discovery_hash(discovery_data) discovery_hash = get_discovery_hash(discovery_data)
async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None) async_dispatcher_send(hass, MQTT_DISCOVERY_DONE.format(discovery_hash), None)
@ -641,7 +651,7 @@ def send_discovery_done(hass: HomeAssistant, discovery_data: dict) -> None:
def stop_discovery_updates( def stop_discovery_updates(
hass: HomeAssistant, hass: HomeAssistant,
discovery_data: dict, discovery_data: DiscoveryInfoType,
remove_discovery_updated: Callable[[], None] | None = None, remove_discovery_updated: Callable[[], None] | None = None,
) -> None: ) -> None:
"""Stop discovery updates of being sent.""" """Stop discovery updates of being sent."""
@ -660,7 +670,7 @@ async def async_remove_discovery_payload(hass: HomeAssistant, discovery_data: di
async def async_clear_discovery_topic_if_entity_removed( async def async_clear_discovery_topic_if_entity_removed(
hass: HomeAssistant, hass: HomeAssistant,
discovery_data: dict[str, Any], discovery_data: DiscoveryInfoType,
event: Event, event: Event,
) -> None: ) -> None:
"""Clear the discovery topic if the entity is removed.""" """Clear the discovery topic if the entity is removed."""
@ -675,7 +685,7 @@ class MqttDiscoveryDeviceUpdate:
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
discovery_data: dict, discovery_data: DiscoveryInfoType,
device_id: str | None, device_id: str | None,
config_entry: ConfigEntry, config_entry: ConfigEntry,
log_name: str, log_name: str,
@ -718,7 +728,7 @@ class MqttDiscoveryDeviceUpdate:
async def async_discovery_update( async def async_discovery_update(
self, self,
discovery_payload: DiscoveryInfoType | None, discovery_payload: MQTTDiscoveryPayload,
) -> None: ) -> None:
"""Handle discovery update.""" """Handle discovery update."""
discovery_hash = get_discovery_hash(self._discovery_data) discovery_hash = get_discovery_hash(self._discovery_data)
@ -789,7 +799,7 @@ class MqttDiscoveryDeviceUpdate:
self.hass, self._device_id, self._config_entry_id self.hass, self._device_id, self._config_entry_id
) )
async def async_update(self, discovery_data: dict) -> None: async def async_update(self, discovery_data: MQTTDiscoveryPayload) -> None:
"""Handle the update of platform specific parts, extend to the platform.""" """Handle the update of platform specific parts, extend to the platform."""
@abstractmethod @abstractmethod
@ -803,8 +813,9 @@ class MqttDiscoveryUpdate(Entity):
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
discovery_data: dict | None, discovery_data: DiscoveryInfoType | None,
discovery_update: Callable | None = None, discovery_update: Callable[[MQTTDiscoveryPayload], Coroutine[Any, Any, None]]
| None = None,
) -> None: ) -> None:
"""Initialize the discovery update mixin.""" """Initialize the discovery update mixin."""
self._discovery_data = discovery_data self._discovery_data = discovery_data
@ -823,11 +834,13 @@ class MqttDiscoveryUpdate(Entity):
"""Subscribe to discovery updates.""" """Subscribe to discovery updates."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._removed_from_hass = False self._removed_from_hass = False
discovery_hash = ( discovery_hash: tuple[str, str] | None = (
self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None self._discovery_data[ATTR_DISCOVERY_HASH] if self._discovery_data else None
) )
async def _async_remove_state_and_registry_entry(self) -> None: async def _async_remove_state_and_registry_entry(
self: MqttDiscoveryUpdate,
) -> None:
"""Remove entity's state and entity registry entry. """Remove entity's state and entity registry entry.
Remove entity from entity registry if it is registered, this also removes the state. Remove entity from entity registry if it is registered, this also removes the state.
@ -842,13 +855,15 @@ class MqttDiscoveryUpdate(Entity):
else: else:
await self.async_remove(force_remove=True) await self.async_remove(force_remove=True)
async def discovery_callback(payload): async def discovery_callback(payload: MQTTDiscoveryPayload) -> None:
"""Handle discovery update.""" """Handle discovery update."""
_LOGGER.info( _LOGGER.info(
"Got update for entity with hash: %s '%s'", "Got update for entity with hash: %s '%s'",
discovery_hash, discovery_hash,
payload, payload,
) )
assert self._discovery_data
old_payload: DiscoveryInfoType
old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD] old_payload = self._discovery_data[ATTR_DISCOVERY_PAYLOAD]
debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id) debug_info.update_entity_discovery_data(self.hass, payload, self.entity_id)
if not payload: if not payload:
@ -923,39 +938,43 @@ class MqttDiscoveryUpdate(Entity):
self._removed_from_hass = True self._removed_from_hass = True
def device_info_from_config(config) -> DeviceInfo | None: def device_info_from_specifications(
specifications: dict[str, Any] | None
) -> DeviceInfo | None:
"""Return a device description for device registry.""" """Return a device description for device registry."""
if not config: if not specifications:
return None return None
info = DeviceInfo( info = DeviceInfo(
identifiers={(DOMAIN, id_) for id_ in config[CONF_IDENTIFIERS]}, identifiers={(DOMAIN, id_) for id_ in specifications[CONF_IDENTIFIERS]},
connections={(conn_[0], conn_[1]) for conn_ in config[CONF_CONNECTIONS]}, connections={
(conn_[0], conn_[1]) for conn_ in specifications[CONF_CONNECTIONS]
},
) )
if CONF_MANUFACTURER in config: if CONF_MANUFACTURER in specifications:
info[ATTR_MANUFACTURER] = config[CONF_MANUFACTURER] info[ATTR_MANUFACTURER] = specifications[CONF_MANUFACTURER]
if CONF_MODEL in config: if CONF_MODEL in specifications:
info[ATTR_MODEL] = config[CONF_MODEL] info[ATTR_MODEL] = specifications[CONF_MODEL]
if CONF_NAME in config: if CONF_NAME in specifications:
info[ATTR_NAME] = config[CONF_NAME] info[ATTR_NAME] = specifications[CONF_NAME]
if CONF_HW_VERSION in config: if CONF_HW_VERSION in specifications:
info[ATTR_HW_VERSION] = config[CONF_HW_VERSION] info[ATTR_HW_VERSION] = specifications[CONF_HW_VERSION]
if CONF_SW_VERSION in config: if CONF_SW_VERSION in specifications:
info[ATTR_SW_VERSION] = config[CONF_SW_VERSION] info[ATTR_SW_VERSION] = specifications[CONF_SW_VERSION]
if CONF_VIA_DEVICE in config: if CONF_VIA_DEVICE in specifications:
info[ATTR_VIA_DEVICE] = (DOMAIN, config[CONF_VIA_DEVICE]) info[ATTR_VIA_DEVICE] = (DOMAIN, specifications[CONF_VIA_DEVICE])
if CONF_SUGGESTED_AREA in config: if CONF_SUGGESTED_AREA in specifications:
info[ATTR_SUGGESTED_AREA] = config[CONF_SUGGESTED_AREA] info[ATTR_SUGGESTED_AREA] = specifications[CONF_SUGGESTED_AREA]
if CONF_CONFIGURATION_URL in config: if CONF_CONFIGURATION_URL in specifications:
info[ATTR_CONFIGURATION_URL] = config[CONF_CONFIGURATION_URL] info[ATTR_CONFIGURATION_URL] = specifications[CONF_CONFIGURATION_URL]
return info return info
@ -963,19 +982,21 @@ def device_info_from_config(config) -> DeviceInfo | None:
class MqttEntityDeviceInfo(Entity): class MqttEntityDeviceInfo(Entity):
"""Mixin used for mqtt platforms that support the device registry.""" """Mixin used for mqtt platforms that support the device registry."""
def __init__(self, device_config: ConfigType | None, config_entry=None) -> None: def __init__(
self, specifications: dict[str, Any] | None, config_entry: ConfigEntry
) -> None:
"""Initialize the device mixin.""" """Initialize the device mixin."""
self._device_config = device_config self._device_specifications = specifications
self._config_entry = config_entry self._config_entry = config_entry
def device_info_discovery_update(self, config: dict): def device_info_discovery_update(self, config: DiscoveryInfoType) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
self._device_config = config.get(CONF_DEVICE) self._device_specifications = config.get(CONF_DEVICE)
device_registry = dr.async_get(self.hass) device_registry = dr.async_get(self.hass)
config_entry_id = self._config_entry.entry_id config_entry_id = self._config_entry.entry_id
device_info = self.device_info device_info = self.device_info
if config_entry_id is not None and device_info is not None: if device_info is not None:
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id=config_entry_id, **device_info config_entry_id=config_entry_id, **device_info
) )
@ -983,7 +1004,7 @@ class MqttEntityDeviceInfo(Entity):
@property @property
def device_info(self) -> DeviceInfo | None: def device_info(self) -> DeviceInfo | None:
"""Return a device description for device registry.""" """Return a device description for device registry."""
return device_info_from_config(self._device_config) return device_info_from_specifications(self._device_specifications)
class MqttEntity( class MqttEntity(
@ -997,12 +1018,18 @@ class MqttEntity(
_attr_should_poll = False _attr_should_poll = False
_entity_id_format: str _entity_id_format: str
def __init__(self, hass, config, config_entry, discovery_data): def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Init the MQTT Entity.""" """Init the MQTT Entity."""
self.hass = hass self.hass = hass
self._config = config self._config: ConfigType = config
self._unique_id = config.get(CONF_UNIQUE_ID) self._unique_id: str | None = config.get(CONF_UNIQUE_ID)
self._sub_state = None self._sub_state: dict[str, EntitySubscription] = {}
# Load config # Load config
self._setup_from_config(self._config) self._setup_from_config(self._config)
@ -1016,14 +1043,14 @@ class MqttEntity(
MqttDiscoveryUpdate.__init__(self, hass, discovery_data, self.discovery_update) MqttDiscoveryUpdate.__init__(self, hass, discovery_data, self.discovery_update)
MqttEntityDeviceInfo.__init__(self, config.get(CONF_DEVICE), config_entry) MqttEntityDeviceInfo.__init__(self, config.get(CONF_DEVICE), config_entry)
def _init_entity_id(self): def _init_entity_id(self) -> None:
"""Set entity_id from object_id if defined in config.""" """Set entity_id from object_id if defined in config."""
init_entity_id_from_config( init_entity_id_from_config(
self.hass, self, self._config, self._entity_id_format self.hass, self, self._config, self._entity_id_format
) )
@final @final
async def async_added_to_hass(self): async def async_added_to_hass(self) -> None:
"""Subscribe to MQTT events.""" """Subscribe to MQTT events."""
await super().async_added_to_hass() await super().async_added_to_hass()
self._prepare_subscribe_topics() self._prepare_subscribe_topics()
@ -1032,15 +1059,15 @@ class MqttEntity(
if self._discovery_data is not None: if self._discovery_data is not None:
send_discovery_done(self.hass, self._discovery_data) send_discovery_done(self.hass, self._discovery_data)
async def mqtt_async_added_to_hass(self): async def mqtt_async_added_to_hass(self) -> None:
"""Call before the discovery message is acknowledged. """Call before the discovery message is acknowledged.
To be extended by subclasses. To be extended by subclasses.
""" """
async def discovery_update(self, discovery_payload): async def discovery_update(self, discovery_payload: MQTTDiscoveryPayload) -> None:
"""Handle updated discovery message.""" """Handle updated discovery message."""
config = self.config_schema()(discovery_payload) config: DiscoveryInfoType = self.config_schema()(discovery_payload)
self._config = config self._config = config
self._setup_from_config(self._config) self._setup_from_config(self._config)
@ -1056,7 +1083,7 @@ class MqttEntity(
await self._subscribe_topics() await self._subscribe_topics()
self.async_write_ha_state() self.async_write_ha_state()
async def async_will_remove_from_hass(self): async def async_will_remove_from_hass(self) -> None:
"""Unsubscribe when removed.""" """Unsubscribe when removed."""
self._sub_state = subscription.async_unsubscribe_topics( self._sub_state = subscription.async_unsubscribe_topics(
self.hass, self._sub_state self.hass, self._sub_state
@ -1073,7 +1100,7 @@ class MqttEntity(
qos: int = 0, qos: int = 0,
retain: bool = False, retain: bool = False,
encoding: str = DEFAULT_ENCODING, encoding: str = DEFAULT_ENCODING,
): ) -> None:
"""Publish message to an MQTT topic.""" """Publish message to an MQTT topic."""
log_message(self.hass, self.entity_id, topic, payload, qos, retain) log_message(self.hass, self.entity_id, topic, payload, qos, retain)
await async_publish( await async_publish(
@ -1087,18 +1114,18 @@ class MqttEntity(
@staticmethod @staticmethod
@abstractmethod @abstractmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
@abstractmethod @abstractmethod
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@abstractmethod @abstractmethod
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@property @property
@ -1112,17 +1139,17 @@ class MqttEntity(
return self._config.get(CONF_ENTITY_CATEGORY) return self._config.get(CONF_ENTITY_CATEGORY)
@property @property
def icon(self): def icon(self) -> str | None:
"""Return icon of the entity if any.""" """Return icon of the entity if any."""
return self._config.get(CONF_ICON) return self._config.get(CONF_ICON)
@property @property
def name(self): def name(self) -> str | None:
"""Return the name of the device if any.""" """Return the name of the device if any."""
return self._config.get(CONF_NAME) return self._config.get(CONF_NAME)
@property @property
def unique_id(self): def unique_id(self) -> str | None:
"""Return a unique ID.""" """Return a unique ID."""
return self._unique_id return self._unique_id
@ -1136,10 +1163,10 @@ def update_device(
if CONF_DEVICE not in config: if CONF_DEVICE not in config:
return None return None
device = None device: DeviceEntry | None = None
device_registry = dr.async_get(hass) device_registry = dr.async_get(hass)
config_entry_id = config_entry.entry_id config_entry_id = config_entry.entry_id
device_info = device_info_from_config(config[CONF_DEVICE]) device_info = device_info_from_specifications(config[CONF_DEVICE])
if config_entry_id is not None and device_info is not None: if config_entry_id is not None and device_info is not None:
update_device_info = cast(dict, device_info) update_device_info = cast(dict, device_info)
@ -1154,7 +1181,7 @@ def async_removed_from_device(
hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str hass: HomeAssistant, event: Event, mqtt_device_id: str, config_entry_id: str
) -> bool: ) -> bool:
"""Check if the passed event indicates MQTT was removed from a device.""" """Check if the passed event indicates MQTT was removed from a device."""
device_id = event.data["device_id"] device_id: str = event.data["device_id"]
if event.data["action"] not in ("remove", "update"): if event.data["action"] not in ("remove", "update"):
return False return False

View File

@ -23,7 +23,7 @@ if TYPE_CHECKING:
from .client import MQTT, Subscription from .client import MQTT, Subscription
from .debug_info import TimestampedPublishMessage from .debug_info import TimestampedPublishMessage
from .device_trigger import Trigger from .device_trigger import Trigger
from .discovery import MQTTConfig from .discovery import MQTTDiscoveryPayload
from .tag import MQTTTagScanner from .tag import MQTTTagScanner
_SENTINEL = object() _SENTINEL = object()
@ -86,7 +86,7 @@ class TriggerDebugInfo(TypedDict):
class PendingDiscovered(TypedDict): class PendingDiscovered(TypedDict):
"""Pending discovered items.""" """Pending discovered items."""
pending: deque[MQTTConfig] pending: deque[MQTTDiscoveryPayload]
unsub: CALLBACK_TYPE unsub: CALLBACK_TYPE

View File

@ -98,8 +98,6 @@ class MqttUpdate(MqttEntity, UpdateEntity, RestoreEntity):
) -> None: ) -> None:
"""Initialize the MQTT update.""" """Initialize the MQTT update."""
self._config = config self._config = config
self._sub_state = None
self._attr_device_class = self._config.get(CONF_DEVICE_CLASS) self._attr_device_class = self._config.get(CONF_DEVICE_CLASS)
UpdateEntity.__init__(self) UpdateEntity.__init__(self)