From 4293c88fb664f9fce73974da4a279decd450dd07 Mon Sep 17 00:00:00 2001 From: Jan Bouwhuis Date: Tue, 8 Nov 2022 13:11:45 +0100 Subject: [PATCH] Improve MQTT type hints part 6 (#81001) * Improve typing siren * Improve typing switch * Set siren type hints at class level * Set switch type hints at class level * Follow up comment * Improve hints on siren templates * Another cleanup for siren * Follow up comment * Follow up comment --- homeassistant/components/mqtt/siren.py | 109 +++++++++++++----------- homeassistant/components/mqtt/switch.py | 36 +++++--- 2 files changed, 83 insertions(+), 62 deletions(-) diff --git a/homeassistant/components/mqtt/siren.py b/homeassistant/components/mqtt/siren.py index 2ab226e44c0..4a69977df45 100644 --- a/homeassistant/components/mqtt/siren.py +++ b/homeassistant/components/mqtt/siren.py @@ -1,10 +1,11 @@ """Support for MQTT sirens.""" from __future__ import annotations +from collections.abc import Callable import copy import functools import logging -from typing import Any +from typing import Any, cast import voluptuous as vol @@ -17,6 +18,7 @@ from homeassistant.components.siren import ( TURN_ON_SCHEMA, SirenEntity, SirenEntityFeature, + SirenTurnOnServiceParameters, process_turn_on_params, ) from homeassistant.config_entries import ConfigEntry @@ -30,7 +32,8 @@ from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads -from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.template import Template +from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType from . import subscription from .config import MQTT_RW_SCHEMA @@ -53,7 +56,13 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttCommandTemplate, MqttValueTemplate +from .models import ( + MqttCommandTemplate, + MqttValueTemplate, + PublishPayloadType, + ReceiveMessage, + ReceivePayloadType, +) from .util import get_mqtt_data DEFAULT_NAME = "MQTT Siren" @@ -150,8 +159,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT siren.""" async_add_entities([MqttSiren(hass, config, config_entry, discovery_data)]) @@ -162,29 +171,32 @@ class MqttSiren(MqttEntity, SirenEntity): _entity_id_format = ENTITY_ID_FORMAT _attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED + _attr_supported_features: int - def __init__(self, hass, config, config_entry, discovery_data): + _command_templates: dict[ + str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None + ] + _value_template: Callable[[ReceivePayloadType], ReceivePayloadType] + _state_on: str + _state_off: str + _optimistic: bool + + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT siren.""" - self._attr_name = config[CONF_NAME] - self._attr_should_poll = False - self._supported_features = SUPPORTED_BASE - self._attr_is_on = None - self._state_on = None - self._state_off = None - self._optimistic = None - - self._attr_extra_state_attributes: dict[str, Any] = {} - - self.target = None - - super().__init__(hass, config, config_entry, discovery_data) + MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA - def _setup_from_config(self, config): + def _setup_from_config(self, config: ConfigType) -> None: """(Re)Setup the entity.""" state_on = config.get(CONF_STATE_ON) @@ -193,25 +205,29 @@ class MqttSiren(MqttEntity, SirenEntity): state_off = config.get(CONF_STATE_OFF) self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] + self._attr_extra_state_attributes = {} + + _supported_features: int = SUPPORTED_BASE if config[CONF_SUPPORT_DURATION]: - self._supported_features |= SirenEntityFeature.DURATION + _supported_features |= SirenEntityFeature.DURATION self._attr_extra_state_attributes[ATTR_DURATION] = None if config.get(CONF_AVAILABLE_TONES): - self._supported_features |= SirenEntityFeature.TONES + _supported_features |= SirenEntityFeature.TONES self._attr_available_tones = config[CONF_AVAILABLE_TONES] self._attr_extra_state_attributes[ATTR_TONE] = None if config[CONF_SUPPORT_VOLUME_SET]: - self._supported_features |= SirenEntityFeature.VOLUME_SET + _supported_features |= SirenEntityFeature.VOLUME_SET self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None + self._attr_supported_features = _supported_features self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config self._attr_is_on = False if self._optimistic else None - command_template = config.get(CONF_COMMAND_TEMPLATE) - command_off_template = config.get(CONF_COMMAND_OFF_TEMPLATE) or config.get( - CONF_COMMAND_TEMPLATE + command_template: Template | None = config.get(CONF_COMMAND_TEMPLATE) + command_off_template: Template | None = ( + config.get(CONF_COMMAND_OFF_TEMPLATE) or command_template ) self._command_templates = { CONF_COMMAND_TEMPLATE: MqttCommandTemplate( @@ -230,12 +246,12 @@ class MqttSiren(MqttEntity, SirenEntity): entity=self, ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def state_message_received(msg): + def state_message_received(msg: ReceiveMessage) -> None: """Handle new MQTT state messages.""" payload = self._value_template(msg.payload) if not payload or payload == PAYLOAD_EMPTY_JSON: @@ -245,7 +261,7 @@ class MqttSiren(MqttEntity, SirenEntity): msg.topic, ) return - json_payload = {} + json_payload: dict[str, Any] = {} if payload in [self._state_on, self._state_off, PAYLOAD_NONE]: json_payload = {STATE: payload} else: @@ -275,7 +291,8 @@ class MqttSiren(MqttEntity, SirenEntity): if json_payload: # process attributes try: - vol.All(TURN_ON_SCHEMA)(json_payload) + params: SirenTurnOnServiceParameters + params = vol.All(TURN_ON_SCHEMA)(json_payload) except vol.MultipleInvalid as invalid_siren_parameters: _LOGGER.warning( "Unable to update siren state attributes from payload '%s': %s", @@ -283,7 +300,7 @@ class MqttSiren(MqttEntity, SirenEntity): invalid_siren_parameters, ) return - self._update(process_turn_on_params(self, json_payload)) + self._update(process_turn_on_params(self, params)) get_mqtt_data(self.hass).state_write_requests.write_state_request(self) if self._config.get(CONF_STATE_TOPIC) is None: @@ -303,7 +320,7 @@ class MqttSiren(MqttEntity, SirenEntity): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state) @@ -322,11 +339,6 @@ class MqttSiren(MqttEntity, SirenEntity): attributes.update(self._attr_extra_state_attributes) return attributes - @property - def supported_features(self) -> int: - """Flag supported features.""" - return self._supported_features - async def _async_publish( self, topic: str, @@ -335,15 +347,14 @@ class MqttSiren(MqttEntity, SirenEntity): variables: dict[str, Any] | None = None, ) -> None: """Publish MQTT payload with optional command template.""" - template_variables = {STATE: value} + template_variables: dict[str, Any] = {STATE: value} if variables is not None: template_variables.update(variables) - payload = ( - self._command_templates[template](value, template_variables) - if self._command_templates[template] - else json_dumps(template_variables) - ) - if payload and payload not in PAYLOAD_NONE: + if command_template := self._command_templates[template]: + payload = command_template(value, template_variables) + else: + payload = json_dumps(template_variables) + if payload and str(payload) != PAYLOAD_NONE: await self.async_publish( self._config[topic], payload, @@ -367,7 +378,7 @@ class MqttSiren(MqttEntity, SirenEntity): # Optimistically assume that siren has changed state. _LOGGER.debug("Writing state attributes %s", kwargs) self._attr_is_on = True - self._update(kwargs) + self._update(cast(SirenTurnOnServiceParameters, kwargs)) self.async_write_ha_state() async def async_turn_off(self, **kwargs: Any) -> None: @@ -386,8 +397,8 @@ class MqttSiren(MqttEntity, SirenEntity): self._attr_is_on = False self.async_write_ha_state() - def _update(self, data: dict[str, Any]) -> None: + def _update(self, data: SirenTurnOnServiceParameters) -> None: """Update the extra siren state attributes.""" for attribute, support in SUPPORTED_ATTRIBUTES.items(): - if self._supported_features & support and attribute in data: - self._attr_extra_state_attributes[attribute] = data[attribute] + if self._attr_supported_features & support and attribute in data: + self._attr_extra_state_attributes[attribute] = data[attribute] # type: ignore[literal-required] diff --git a/homeassistant/components/mqtt/switch.py b/homeassistant/components/mqtt/switch.py index f2a40facd8b..a20603e2399 100644 --- a/homeassistant/components/mqtt/switch.py +++ b/homeassistant/components/mqtt/switch.py @@ -1,6 +1,7 @@ """Support for MQTT switches.""" from __future__ import annotations +from collections.abc import Callable import functools from typing import Any @@ -22,6 +23,7 @@ from homeassistant.core import HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.restore_state import RestoreEntity +from homeassistant.helpers.service_info.mqtt import ReceivePayloadType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from . import subscription @@ -42,7 +44,7 @@ from .mixins import ( async_setup_platform_helper, warn_for_legacy_schema, ) -from .models import MqttValueTemplate +from .models import MqttValueTemplate, ReceiveMessage from .util import get_mqtt_data DEFAULT_NAME = "MQTT Switch" @@ -107,8 +109,8 @@ async def _async_setup_entity( hass: HomeAssistant, async_add_entities: AddEntitiesCallback, config: ConfigType, - config_entry: ConfigEntry | None = None, - discovery_data: dict | None = None, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None = None, ) -> None: """Set up the MQTT switch.""" async_add_entities([MqttSwitch(hass, config, config_entry, discovery_data)]) @@ -119,16 +121,24 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): _entity_id_format = switch.ENTITY_ID_FORMAT - def __init__(self, hass, config, config_entry, discovery_data): + _optimistic: bool + _state_on: str + _state_off: str + _value_template: Callable[[ReceivePayloadType], ReceivePayloadType] + + def __init__( + self, + hass: HomeAssistant, + config: ConfigType, + config_entry: ConfigEntry, + discovery_data: DiscoveryInfoType | None, + ) -> None: """Initialize the MQTT switch.""" - self._state_on = None - self._state_off = None - self._optimistic = None MqttEntity.__init__(self, hass, config, config_entry, discovery_data) @staticmethod - def config_schema(): + def config_schema() -> vol.Schema: """Return the config schema.""" return DISCOVERY_SCHEMA @@ -136,10 +146,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): """(Re)Setup the entity.""" self._attr_device_class = config.get(CONF_DEVICE_CLASS) - state_on = config.get(CONF_STATE_ON) + state_on: str | None = config.get(CONF_STATE_ON) self._state_on = state_on if state_on else config[CONF_PAYLOAD_ON] - state_off = config.get(CONF_STATE_OFF) + state_off: str | None = config.get(CONF_STATE_OFF) self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] self._optimistic = ( @@ -150,12 +160,12 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): self._config.get(CONF_VALUE_TEMPLATE), entity=self ).async_render_with_possible_json_value - def _prepare_subscribe_topics(self): + def _prepare_subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" @callback @log_messages(self.hass, self.entity_id) - def state_message_received(msg): + def state_message_received(msg: ReceiveMessage) -> None: """Handle new MQTT state messages.""" payload = self._value_template(msg.payload) if payload == self._state_on: @@ -184,7 +194,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity): }, ) - async def _subscribe_topics(self): + async def _subscribe_topics(self) -> None: """(Re)Subscribe to topics.""" await subscription.async_subscribe_topics(self.hass, self._sub_state)