Improve MQTT type hints part 6 (#81001)

* Improve typing siren

* Improve typing switch

* Set siren type hints at class level

* Set switch type hints at class level

* Follow up comment

* Improve hints on siren templates

* Another cleanup for siren

* Follow up comment

* Follow up comment
This commit is contained in:
Jan Bouwhuis 2022-11-08 13:11:45 +01:00 committed by GitHub
parent 47dba6f6bc
commit 4293c88fb6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 83 additions and 62 deletions

View File

@ -1,10 +1,11 @@
"""Support for MQTT sirens.""" """Support for MQTT sirens."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import copy import copy
import functools import functools
import logging import logging
from typing import Any from typing import Any, cast
import voluptuous as vol import voluptuous as vol
@ -17,6 +18,7 @@ from homeassistant.components.siren import (
TURN_ON_SCHEMA, TURN_ON_SCHEMA,
SirenEntity, SirenEntity,
SirenEntityFeature, SirenEntityFeature,
SirenTurnOnServiceParameters,
process_turn_on_params, process_turn_on_params,
) )
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -30,7 +32,8 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.template import Template
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
from . import subscription from . import subscription
from .config import MQTT_RW_SCHEMA from .config import MQTT_RW_SCHEMA
@ -53,7 +56,13 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttCommandTemplate, MqttValueTemplate from .models import (
MqttCommandTemplate,
MqttValueTemplate,
PublishPayloadType,
ReceiveMessage,
ReceivePayloadType,
)
from .util import get_mqtt_data from .util import get_mqtt_data
DEFAULT_NAME = "MQTT Siren" DEFAULT_NAME = "MQTT Siren"
@ -150,8 +159,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT siren.""" """Set up the MQTT siren."""
async_add_entities([MqttSiren(hass, config, config_entry, discovery_data)]) async_add_entities([MqttSiren(hass, config, config_entry, discovery_data)])
@ -162,29 +171,32 @@ class MqttSiren(MqttEntity, SirenEntity):
_entity_id_format = ENTITY_ID_FORMAT _entity_id_format = ENTITY_ID_FORMAT
_attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED _attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED
_attr_supported_features: int
def __init__(self, hass, config, config_entry, discovery_data): _command_templates: dict[
str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None
]
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
_state_on: str
_state_off: str
_optimistic: bool
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT siren.""" """Initialize the MQTT siren."""
self._attr_name = config[CONF_NAME] MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
self._attr_should_poll = False
self._supported_features = SUPPORTED_BASE
self._attr_is_on = None
self._state_on = None
self._state_off = None
self._optimistic = None
self._attr_extra_state_attributes: dict[str, Any] = {}
self.target = None
super().__init__(hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
def _setup_from_config(self, config): def _setup_from_config(self, config: ConfigType) -> None:
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
state_on = config.get(CONF_STATE_ON) state_on = config.get(CONF_STATE_ON)
@ -193,25 +205,29 @@ class MqttSiren(MqttEntity, SirenEntity):
state_off = config.get(CONF_STATE_OFF) state_off = config.get(CONF_STATE_OFF)
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
self._attr_extra_state_attributes = {}
_supported_features: int = SUPPORTED_BASE
if config[CONF_SUPPORT_DURATION]: if config[CONF_SUPPORT_DURATION]:
self._supported_features |= SirenEntityFeature.DURATION _supported_features |= SirenEntityFeature.DURATION
self._attr_extra_state_attributes[ATTR_DURATION] = None self._attr_extra_state_attributes[ATTR_DURATION] = None
if config.get(CONF_AVAILABLE_TONES): if config.get(CONF_AVAILABLE_TONES):
self._supported_features |= SirenEntityFeature.TONES _supported_features |= SirenEntityFeature.TONES
self._attr_available_tones = config[CONF_AVAILABLE_TONES] self._attr_available_tones = config[CONF_AVAILABLE_TONES]
self._attr_extra_state_attributes[ATTR_TONE] = None self._attr_extra_state_attributes[ATTR_TONE] = None
if config[CONF_SUPPORT_VOLUME_SET]: if config[CONF_SUPPORT_VOLUME_SET]:
self._supported_features |= SirenEntityFeature.VOLUME_SET _supported_features |= SirenEntityFeature.VOLUME_SET
self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None
self._attr_supported_features = _supported_features
self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config
self._attr_is_on = False if self._optimistic else None self._attr_is_on = False if self._optimistic else None
command_template = config.get(CONF_COMMAND_TEMPLATE) command_template: Template | None = config.get(CONF_COMMAND_TEMPLATE)
command_off_template = config.get(CONF_COMMAND_OFF_TEMPLATE) or config.get( command_off_template: Template | None = (
CONF_COMMAND_TEMPLATE config.get(CONF_COMMAND_OFF_TEMPLATE) or command_template
) )
self._command_templates = { self._command_templates = {
CONF_COMMAND_TEMPLATE: MqttCommandTemplate( CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
@ -230,12 +246,12 @@ class MqttSiren(MqttEntity, SirenEntity):
entity=self, entity=self,
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_message_received(msg): def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if not payload or payload == PAYLOAD_EMPTY_JSON: if not payload or payload == PAYLOAD_EMPTY_JSON:
@ -245,7 +261,7 @@ class MqttSiren(MqttEntity, SirenEntity):
msg.topic, msg.topic,
) )
return return
json_payload = {} json_payload: dict[str, Any] = {}
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]: if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
json_payload = {STATE: payload} json_payload = {STATE: payload}
else: else:
@ -275,7 +291,8 @@ class MqttSiren(MqttEntity, SirenEntity):
if json_payload: if json_payload:
# process attributes # process attributes
try: try:
vol.All(TURN_ON_SCHEMA)(json_payload) params: SirenTurnOnServiceParameters
params = vol.All(TURN_ON_SCHEMA)(json_payload)
except vol.MultipleInvalid as invalid_siren_parameters: except vol.MultipleInvalid as invalid_siren_parameters:
_LOGGER.warning( _LOGGER.warning(
"Unable to update siren state attributes from payload '%s': %s", "Unable to update siren state attributes from payload '%s': %s",
@ -283,7 +300,7 @@ class MqttSiren(MqttEntity, SirenEntity):
invalid_siren_parameters, invalid_siren_parameters,
) )
return return
self._update(process_turn_on_params(self, json_payload)) self._update(process_turn_on_params(self, params))
get_mqtt_data(self.hass).state_write_requests.write_state_request(self) get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
if self._config.get(CONF_STATE_TOPIC) is None: if self._config.get(CONF_STATE_TOPIC) is None:
@ -303,7 +320,7 @@ class MqttSiren(MqttEntity, SirenEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)
@ -322,11 +339,6 @@ class MqttSiren(MqttEntity, SirenEntity):
attributes.update(self._attr_extra_state_attributes) attributes.update(self._attr_extra_state_attributes)
return attributes return attributes
@property
def supported_features(self) -> int:
"""Flag supported features."""
return self._supported_features
async def _async_publish( async def _async_publish(
self, self,
topic: str, topic: str,
@ -335,15 +347,14 @@ class MqttSiren(MqttEntity, SirenEntity):
variables: dict[str, Any] | None = None, variables: dict[str, Any] | None = None,
) -> None: ) -> None:
"""Publish MQTT payload with optional command template.""" """Publish MQTT payload with optional command template."""
template_variables = {STATE: value} template_variables: dict[str, Any] = {STATE: value}
if variables is not None: if variables is not None:
template_variables.update(variables) template_variables.update(variables)
payload = ( if command_template := self._command_templates[template]:
self._command_templates[template](value, template_variables) payload = command_template(value, template_variables)
if self._command_templates[template] else:
else json_dumps(template_variables) payload = json_dumps(template_variables)
) if payload and str(payload) != PAYLOAD_NONE:
if payload and payload not in PAYLOAD_NONE:
await self.async_publish( await self.async_publish(
self._config[topic], self._config[topic],
payload, payload,
@ -367,7 +378,7 @@ class MqttSiren(MqttEntity, SirenEntity):
# Optimistically assume that siren has changed state. # Optimistically assume that siren has changed state.
_LOGGER.debug("Writing state attributes %s", kwargs) _LOGGER.debug("Writing state attributes %s", kwargs)
self._attr_is_on = True self._attr_is_on = True
self._update(kwargs) self._update(cast(SirenTurnOnServiceParameters, kwargs))
self.async_write_ha_state() self.async_write_ha_state()
async def async_turn_off(self, **kwargs: Any) -> None: async def async_turn_off(self, **kwargs: Any) -> None:
@ -386,8 +397,8 @@ class MqttSiren(MqttEntity, SirenEntity):
self._attr_is_on = False self._attr_is_on = False
self.async_write_ha_state() self.async_write_ha_state()
def _update(self, data: dict[str, Any]) -> None: def _update(self, data: SirenTurnOnServiceParameters) -> None:
"""Update the extra siren state attributes.""" """Update the extra siren state attributes."""
for attribute, support in SUPPORTED_ATTRIBUTES.items(): for attribute, support in SUPPORTED_ATTRIBUTES.items():
if self._supported_features & support and attribute in data: if self._attr_supported_features & support and attribute in data:
self._attr_extra_state_attributes[attribute] = data[attribute] self._attr_extra_state_attributes[attribute] = data[attribute] # type: ignore[literal-required]

View File

@ -1,6 +1,7 @@
"""Support for MQTT switches.""" """Support for MQTT switches."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
import functools import functools
from typing import Any from typing import Any
@ -22,6 +23,7 @@ from homeassistant.core import HomeAssistant, callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
from . import subscription from . import subscription
@ -42,7 +44,7 @@ from .mixins import (
async_setup_platform_helper, async_setup_platform_helper,
warn_for_legacy_schema, warn_for_legacy_schema,
) )
from .models import MqttValueTemplate from .models import MqttValueTemplate, ReceiveMessage
from .util import get_mqtt_data from .util import get_mqtt_data
DEFAULT_NAME = "MQTT Switch" DEFAULT_NAME = "MQTT Switch"
@ -107,8 +109,8 @@ async def _async_setup_entity(
hass: HomeAssistant, hass: HomeAssistant,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
config: ConfigType, config: ConfigType,
config_entry: ConfigEntry | None = None, config_entry: ConfigEntry,
discovery_data: dict | None = None, discovery_data: DiscoveryInfoType | None = None,
) -> None: ) -> None:
"""Set up the MQTT switch.""" """Set up the MQTT switch."""
async_add_entities([MqttSwitch(hass, config, config_entry, discovery_data)]) async_add_entities([MqttSwitch(hass, config, config_entry, discovery_data)])
@ -119,16 +121,24 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
_entity_id_format = switch.ENTITY_ID_FORMAT _entity_id_format = switch.ENTITY_ID_FORMAT
def __init__(self, hass, config, config_entry, discovery_data): _optimistic: bool
_state_on: str
_state_off: str
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
def __init__(
self,
hass: HomeAssistant,
config: ConfigType,
config_entry: ConfigEntry,
discovery_data: DiscoveryInfoType | None,
) -> None:
"""Initialize the MQTT switch.""" """Initialize the MQTT switch."""
self._state_on = None
self._state_off = None
self._optimistic = None
MqttEntity.__init__(self, hass, config, config_entry, discovery_data) MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
@staticmethod @staticmethod
def config_schema(): def config_schema() -> vol.Schema:
"""Return the config schema.""" """Return the config schema."""
return DISCOVERY_SCHEMA return DISCOVERY_SCHEMA
@ -136,10 +146,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
"""(Re)Setup the entity.""" """(Re)Setup the entity."""
self._attr_device_class = config.get(CONF_DEVICE_CLASS) self._attr_device_class = config.get(CONF_DEVICE_CLASS)
state_on = config.get(CONF_STATE_ON) state_on: str | None = config.get(CONF_STATE_ON)
self._state_on = state_on if state_on else config[CONF_PAYLOAD_ON] self._state_on = state_on if state_on else config[CONF_PAYLOAD_ON]
state_off = config.get(CONF_STATE_OFF) state_off: str | None = config.get(CONF_STATE_OFF)
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF] self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
self._optimistic = ( self._optimistic = (
@ -150,12 +160,12 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
self._config.get(CONF_VALUE_TEMPLATE), entity=self self._config.get(CONF_VALUE_TEMPLATE), entity=self
).async_render_with_possible_json_value ).async_render_with_possible_json_value
def _prepare_subscribe_topics(self): def _prepare_subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
@callback @callback
@log_messages(self.hass, self.entity_id) @log_messages(self.hass, self.entity_id)
def state_message_received(msg): def state_message_received(msg: ReceiveMessage) -> None:
"""Handle new MQTT state messages.""" """Handle new MQTT state messages."""
payload = self._value_template(msg.payload) payload = self._value_template(msg.payload)
if payload == self._state_on: if payload == self._state_on:
@ -184,7 +194,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
}, },
) )
async def _subscribe_topics(self): async def _subscribe_topics(self) -> None:
"""(Re)Subscribe to topics.""" """(Re)Subscribe to topics."""
await subscription.async_subscribe_topics(self.hass, self._sub_state) await subscription.async_subscribe_topics(self.hass, self._sub_state)