mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Improve MQTT type hints part 6 (#81001)
* Improve typing siren * Improve typing switch * Set siren type hints at class level * Set switch type hints at class level * Follow up comment * Improve hints on siren templates * Another cleanup for siren * Follow up comment * Follow up comment
This commit is contained in:
parent
47dba6f6bc
commit
4293c88fb6
@ -1,10 +1,11 @@
|
|||||||
"""Support for MQTT sirens."""
|
"""Support for MQTT sirens."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import copy
|
import copy
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any, cast
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
@ -17,6 +18,7 @@ from homeassistant.components.siren import (
|
|||||||
TURN_ON_SCHEMA,
|
TURN_ON_SCHEMA,
|
||||||
SirenEntity,
|
SirenEntity,
|
||||||
SirenEntityFeature,
|
SirenEntityFeature,
|
||||||
|
SirenTurnOnServiceParameters,
|
||||||
process_turn_on_params,
|
process_turn_on_params,
|
||||||
)
|
)
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@ -30,7 +32,8 @@ from homeassistant.core import HomeAssistant, callback
|
|||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads
|
from homeassistant.helpers.json import JSON_DECODE_EXCEPTIONS, json_dumps, json_loads
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.template import Template
|
||||||
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType, TemplateVarsType
|
||||||
|
|
||||||
from . import subscription
|
from . import subscription
|
||||||
from .config import MQTT_RW_SCHEMA
|
from .config import MQTT_RW_SCHEMA
|
||||||
@ -53,7 +56,13 @@ from .mixins import (
|
|||||||
async_setup_platform_helper,
|
async_setup_platform_helper,
|
||||||
warn_for_legacy_schema,
|
warn_for_legacy_schema,
|
||||||
)
|
)
|
||||||
from .models import MqttCommandTemplate, MqttValueTemplate
|
from .models import (
|
||||||
|
MqttCommandTemplate,
|
||||||
|
MqttValueTemplate,
|
||||||
|
PublishPayloadType,
|
||||||
|
ReceiveMessage,
|
||||||
|
ReceivePayloadType,
|
||||||
|
)
|
||||||
from .util import get_mqtt_data
|
from .util import get_mqtt_data
|
||||||
|
|
||||||
DEFAULT_NAME = "MQTT Siren"
|
DEFAULT_NAME = "MQTT Siren"
|
||||||
@ -150,8 +159,8 @@ async def _async_setup_entity(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
config_entry: ConfigEntry | None = None,
|
config_entry: ConfigEntry,
|
||||||
discovery_data: dict | None = None,
|
discovery_data: DiscoveryInfoType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the MQTT siren."""
|
"""Set up the MQTT siren."""
|
||||||
async_add_entities([MqttSiren(hass, config, config_entry, discovery_data)])
|
async_add_entities([MqttSiren(hass, config, config_entry, discovery_data)])
|
||||||
@ -162,29 +171,32 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
|
|
||||||
_entity_id_format = ENTITY_ID_FORMAT
|
_entity_id_format = ENTITY_ID_FORMAT
|
||||||
_attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED
|
_attributes_extra_blocked = MQTT_SIREN_ATTRIBUTES_BLOCKED
|
||||||
|
_attr_supported_features: int
|
||||||
|
|
||||||
def __init__(self, hass, config, config_entry, discovery_data):
|
_command_templates: dict[
|
||||||
|
str, Callable[[PublishPayloadType, TemplateVarsType], PublishPayloadType] | None
|
||||||
|
]
|
||||||
|
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
|
||||||
|
_state_on: str
|
||||||
|
_state_off: str
|
||||||
|
_optimistic: bool
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
discovery_data: DiscoveryInfoType | None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the MQTT siren."""
|
"""Initialize the MQTT siren."""
|
||||||
self._attr_name = config[CONF_NAME]
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
self._attr_should_poll = False
|
|
||||||
self._supported_features = SUPPORTED_BASE
|
|
||||||
self._attr_is_on = None
|
|
||||||
self._state_on = None
|
|
||||||
self._state_off = None
|
|
||||||
self._optimistic = None
|
|
||||||
|
|
||||||
self._attr_extra_state_attributes: dict[str, Any] = {}
|
|
||||||
|
|
||||||
self.target = None
|
|
||||||
|
|
||||||
super().__init__(hass, config, config_entry, discovery_data)
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def config_schema():
|
def config_schema() -> vol.Schema:
|
||||||
"""Return the config schema."""
|
"""Return the config schema."""
|
||||||
return DISCOVERY_SCHEMA
|
return DISCOVERY_SCHEMA
|
||||||
|
|
||||||
def _setup_from_config(self, config):
|
def _setup_from_config(self, config: ConfigType) -> None:
|
||||||
"""(Re)Setup the entity."""
|
"""(Re)Setup the entity."""
|
||||||
|
|
||||||
state_on = config.get(CONF_STATE_ON)
|
state_on = config.get(CONF_STATE_ON)
|
||||||
@ -193,25 +205,29 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
state_off = config.get(CONF_STATE_OFF)
|
state_off = config.get(CONF_STATE_OFF)
|
||||||
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
|
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
|
||||||
|
|
||||||
|
self._attr_extra_state_attributes = {}
|
||||||
|
|
||||||
|
_supported_features: int = SUPPORTED_BASE
|
||||||
if config[CONF_SUPPORT_DURATION]:
|
if config[CONF_SUPPORT_DURATION]:
|
||||||
self._supported_features |= SirenEntityFeature.DURATION
|
_supported_features |= SirenEntityFeature.DURATION
|
||||||
self._attr_extra_state_attributes[ATTR_DURATION] = None
|
self._attr_extra_state_attributes[ATTR_DURATION] = None
|
||||||
|
|
||||||
if config.get(CONF_AVAILABLE_TONES):
|
if config.get(CONF_AVAILABLE_TONES):
|
||||||
self._supported_features |= SirenEntityFeature.TONES
|
_supported_features |= SirenEntityFeature.TONES
|
||||||
self._attr_available_tones = config[CONF_AVAILABLE_TONES]
|
self._attr_available_tones = config[CONF_AVAILABLE_TONES]
|
||||||
self._attr_extra_state_attributes[ATTR_TONE] = None
|
self._attr_extra_state_attributes[ATTR_TONE] = None
|
||||||
|
|
||||||
if config[CONF_SUPPORT_VOLUME_SET]:
|
if config[CONF_SUPPORT_VOLUME_SET]:
|
||||||
self._supported_features |= SirenEntityFeature.VOLUME_SET
|
_supported_features |= SirenEntityFeature.VOLUME_SET
|
||||||
self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None
|
self._attr_extra_state_attributes[ATTR_VOLUME_LEVEL] = None
|
||||||
|
|
||||||
|
self._attr_supported_features = _supported_features
|
||||||
self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config
|
self._optimistic = config[CONF_OPTIMISTIC] or CONF_STATE_TOPIC not in config
|
||||||
self._attr_is_on = False if self._optimistic else None
|
self._attr_is_on = False if self._optimistic else None
|
||||||
|
|
||||||
command_template = config.get(CONF_COMMAND_TEMPLATE)
|
command_template: Template | None = config.get(CONF_COMMAND_TEMPLATE)
|
||||||
command_off_template = config.get(CONF_COMMAND_OFF_TEMPLATE) or config.get(
|
command_off_template: Template | None = (
|
||||||
CONF_COMMAND_TEMPLATE
|
config.get(CONF_COMMAND_OFF_TEMPLATE) or command_template
|
||||||
)
|
)
|
||||||
self._command_templates = {
|
self._command_templates = {
|
||||||
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
|
CONF_COMMAND_TEMPLATE: MqttCommandTemplate(
|
||||||
@ -230,12 +246,12 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
entity=self,
|
entity=self,
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self):
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def state_message_received(msg):
|
def state_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
if not payload or payload == PAYLOAD_EMPTY_JSON:
|
||||||
@ -245,7 +261,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
msg.topic,
|
msg.topic,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
json_payload = {}
|
json_payload: dict[str, Any] = {}
|
||||||
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
|
if payload in [self._state_on, self._state_off, PAYLOAD_NONE]:
|
||||||
json_payload = {STATE: payload}
|
json_payload = {STATE: payload}
|
||||||
else:
|
else:
|
||||||
@ -275,7 +291,8 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
if json_payload:
|
if json_payload:
|
||||||
# process attributes
|
# process attributes
|
||||||
try:
|
try:
|
||||||
vol.All(TURN_ON_SCHEMA)(json_payload)
|
params: SirenTurnOnServiceParameters
|
||||||
|
params = vol.All(TURN_ON_SCHEMA)(json_payload)
|
||||||
except vol.MultipleInvalid as invalid_siren_parameters:
|
except vol.MultipleInvalid as invalid_siren_parameters:
|
||||||
_LOGGER.warning(
|
_LOGGER.warning(
|
||||||
"Unable to update siren state attributes from payload '%s': %s",
|
"Unable to update siren state attributes from payload '%s': %s",
|
||||||
@ -283,7 +300,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
invalid_siren_parameters,
|
invalid_siren_parameters,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
self._update(process_turn_on_params(self, json_payload))
|
self._update(process_turn_on_params(self, params))
|
||||||
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
get_mqtt_data(self.hass).state_write_requests.write_state_request(self)
|
||||||
|
|
||||||
if self._config.get(CONF_STATE_TOPIC) is None:
|
if self._config.get(CONF_STATE_TOPIC) is None:
|
||||||
@ -303,7 +320,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _subscribe_topics(self):
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||||
|
|
||||||
@ -322,11 +339,6 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
attributes.update(self._attr_extra_state_attributes)
|
attributes.update(self._attr_extra_state_attributes)
|
||||||
return attributes
|
return attributes
|
||||||
|
|
||||||
@property
|
|
||||||
def supported_features(self) -> int:
|
|
||||||
"""Flag supported features."""
|
|
||||||
return self._supported_features
|
|
||||||
|
|
||||||
async def _async_publish(
|
async def _async_publish(
|
||||||
self,
|
self,
|
||||||
topic: str,
|
topic: str,
|
||||||
@ -335,15 +347,14 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
variables: dict[str, Any] | None = None,
|
variables: dict[str, Any] | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Publish MQTT payload with optional command template."""
|
"""Publish MQTT payload with optional command template."""
|
||||||
template_variables = {STATE: value}
|
template_variables: dict[str, Any] = {STATE: value}
|
||||||
if variables is not None:
|
if variables is not None:
|
||||||
template_variables.update(variables)
|
template_variables.update(variables)
|
||||||
payload = (
|
if command_template := self._command_templates[template]:
|
||||||
self._command_templates[template](value, template_variables)
|
payload = command_template(value, template_variables)
|
||||||
if self._command_templates[template]
|
else:
|
||||||
else json_dumps(template_variables)
|
payload = json_dumps(template_variables)
|
||||||
)
|
if payload and str(payload) != PAYLOAD_NONE:
|
||||||
if payload and payload not in PAYLOAD_NONE:
|
|
||||||
await self.async_publish(
|
await self.async_publish(
|
||||||
self._config[topic],
|
self._config[topic],
|
||||||
payload,
|
payload,
|
||||||
@ -367,7 +378,7 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
# Optimistically assume that siren has changed state.
|
# Optimistically assume that siren has changed state.
|
||||||
_LOGGER.debug("Writing state attributes %s", kwargs)
|
_LOGGER.debug("Writing state attributes %s", kwargs)
|
||||||
self._attr_is_on = True
|
self._attr_is_on = True
|
||||||
self._update(kwargs)
|
self._update(cast(SirenTurnOnServiceParameters, kwargs))
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
async def async_turn_off(self, **kwargs: Any) -> None:
|
async def async_turn_off(self, **kwargs: Any) -> None:
|
||||||
@ -386,8 +397,8 @@ class MqttSiren(MqttEntity, SirenEntity):
|
|||||||
self._attr_is_on = False
|
self._attr_is_on = False
|
||||||
self.async_write_ha_state()
|
self.async_write_ha_state()
|
||||||
|
|
||||||
def _update(self, data: dict[str, Any]) -> None:
|
def _update(self, data: SirenTurnOnServiceParameters) -> None:
|
||||||
"""Update the extra siren state attributes."""
|
"""Update the extra siren state attributes."""
|
||||||
for attribute, support in SUPPORTED_ATTRIBUTES.items():
|
for attribute, support in SUPPORTED_ATTRIBUTES.items():
|
||||||
if self._supported_features & support and attribute in data:
|
if self._attr_supported_features & support and attribute in data:
|
||||||
self._attr_extra_state_attributes[attribute] = data[attribute]
|
self._attr_extra_state_attributes[attribute] = data[attribute] # type: ignore[literal-required]
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
"""Support for MQTT switches."""
|
"""Support for MQTT switches."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import Callable
|
||||||
import functools
|
import functools
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@ -22,6 +23,7 @@ from homeassistant.core import HomeAssistant, callback
|
|||||||
import homeassistant.helpers.config_validation as cv
|
import homeassistant.helpers.config_validation as cv
|
||||||
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
from homeassistant.helpers.entity_platform import AddEntitiesCallback
|
||||||
from homeassistant.helpers.restore_state import RestoreEntity
|
from homeassistant.helpers.restore_state import RestoreEntity
|
||||||
|
from homeassistant.helpers.service_info.mqtt import ReceivePayloadType
|
||||||
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType
|
||||||
|
|
||||||
from . import subscription
|
from . import subscription
|
||||||
@ -42,7 +44,7 @@ from .mixins import (
|
|||||||
async_setup_platform_helper,
|
async_setup_platform_helper,
|
||||||
warn_for_legacy_schema,
|
warn_for_legacy_schema,
|
||||||
)
|
)
|
||||||
from .models import MqttValueTemplate
|
from .models import MqttValueTemplate, ReceiveMessage
|
||||||
from .util import get_mqtt_data
|
from .util import get_mqtt_data
|
||||||
|
|
||||||
DEFAULT_NAME = "MQTT Switch"
|
DEFAULT_NAME = "MQTT Switch"
|
||||||
@ -107,8 +109,8 @@ async def _async_setup_entity(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
async_add_entities: AddEntitiesCallback,
|
async_add_entities: AddEntitiesCallback,
|
||||||
config: ConfigType,
|
config: ConfigType,
|
||||||
config_entry: ConfigEntry | None = None,
|
config_entry: ConfigEntry,
|
||||||
discovery_data: dict | None = None,
|
discovery_data: DiscoveryInfoType | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up the MQTT switch."""
|
"""Set up the MQTT switch."""
|
||||||
async_add_entities([MqttSwitch(hass, config, config_entry, discovery_data)])
|
async_add_entities([MqttSwitch(hass, config, config_entry, discovery_data)])
|
||||||
@ -119,16 +121,24 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
|
|
||||||
_entity_id_format = switch.ENTITY_ID_FORMAT
|
_entity_id_format = switch.ENTITY_ID_FORMAT
|
||||||
|
|
||||||
def __init__(self, hass, config, config_entry, discovery_data):
|
_optimistic: bool
|
||||||
|
_state_on: str
|
||||||
|
_state_off: str
|
||||||
|
_value_template: Callable[[ReceivePayloadType], ReceivePayloadType]
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config: ConfigType,
|
||||||
|
config_entry: ConfigEntry,
|
||||||
|
discovery_data: DiscoveryInfoType | None,
|
||||||
|
) -> None:
|
||||||
"""Initialize the MQTT switch."""
|
"""Initialize the MQTT switch."""
|
||||||
self._state_on = None
|
|
||||||
self._state_off = None
|
|
||||||
self._optimistic = None
|
|
||||||
|
|
||||||
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
MqttEntity.__init__(self, hass, config, config_entry, discovery_data)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def config_schema():
|
def config_schema() -> vol.Schema:
|
||||||
"""Return the config schema."""
|
"""Return the config schema."""
|
||||||
return DISCOVERY_SCHEMA
|
return DISCOVERY_SCHEMA
|
||||||
|
|
||||||
@ -136,10 +146,10 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
"""(Re)Setup the entity."""
|
"""(Re)Setup the entity."""
|
||||||
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
|
self._attr_device_class = config.get(CONF_DEVICE_CLASS)
|
||||||
|
|
||||||
state_on = config.get(CONF_STATE_ON)
|
state_on: str | None = config.get(CONF_STATE_ON)
|
||||||
self._state_on = state_on if state_on else config[CONF_PAYLOAD_ON]
|
self._state_on = state_on if state_on else config[CONF_PAYLOAD_ON]
|
||||||
|
|
||||||
state_off = config.get(CONF_STATE_OFF)
|
state_off: str | None = config.get(CONF_STATE_OFF)
|
||||||
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
|
self._state_off = state_off if state_off else config[CONF_PAYLOAD_OFF]
|
||||||
|
|
||||||
self._optimistic = (
|
self._optimistic = (
|
||||||
@ -150,12 +160,12 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
self._config.get(CONF_VALUE_TEMPLATE), entity=self
|
||||||
).async_render_with_possible_json_value
|
).async_render_with_possible_json_value
|
||||||
|
|
||||||
def _prepare_subscribe_topics(self):
|
def _prepare_subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@log_messages(self.hass, self.entity_id)
|
@log_messages(self.hass, self.entity_id)
|
||||||
def state_message_received(msg):
|
def state_message_received(msg: ReceiveMessage) -> None:
|
||||||
"""Handle new MQTT state messages."""
|
"""Handle new MQTT state messages."""
|
||||||
payload = self._value_template(msg.payload)
|
payload = self._value_template(msg.payload)
|
||||||
if payload == self._state_on:
|
if payload == self._state_on:
|
||||||
@ -184,7 +194,7 @@ class MqttSwitch(MqttEntity, SwitchEntity, RestoreEntity):
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _subscribe_topics(self):
|
async def _subscribe_topics(self) -> None:
|
||||||
"""(Re)Subscribe to topics."""
|
"""(Re)Subscribe to topics."""
|
||||||
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
await subscription.async_subscribe_topics(self.hass, self._sub_state)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user