mirror of
https://github.com/home-assistant/core.git
synced 2025-07-20 03:37:07 +00:00
Enable basic type checking for mqtt (#52463)
* Enable basic type checking for mqtt * Tweak
This commit is contained in:
parent
5321151799
commit
79ee112490
@ -9,7 +9,7 @@ import logging
|
||||
from operator import attrgetter
|
||||
import ssl
|
||||
import time
|
||||
from typing import Any, Callable, Union
|
||||
from typing import Any, Awaitable, Callable, Union, cast
|
||||
import uuid
|
||||
|
||||
import attr
|
||||
@ -73,7 +73,12 @@ from .const import (
|
||||
PROTOCOL_311,
|
||||
)
|
||||
from .discovery import LAST_DISCOVERY
|
||||
from .models import Message, MessageCallbackType, PublishPayloadType
|
||||
from .models import (
|
||||
AsyncMessageCallbackType,
|
||||
Message,
|
||||
MessageCallbackType,
|
||||
PublishPayloadType,
|
||||
)
|
||||
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
@ -284,26 +289,36 @@ def async_publish_template(
|
||||
hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data))
|
||||
|
||||
|
||||
def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:
|
||||
AsyncDeprecatedMessageCallbackType = Callable[
|
||||
[str, PublishPayloadType, int], Awaitable[None]
|
||||
]
|
||||
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
|
||||
|
||||
|
||||
def wrap_msg_callback(
|
||||
msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
|
||||
) -> AsyncMessageCallbackType | MessageCallbackType:
|
||||
"""Wrap an MQTT message callback to support deprecated signature."""
|
||||
# Check for partials to properly determine if coroutine function
|
||||
check_func = msg_callback
|
||||
while isinstance(check_func, partial):
|
||||
check_func = check_func.func
|
||||
|
||||
wrapper_func = None
|
||||
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
|
||||
if asyncio.iscoroutinefunction(check_func):
|
||||
|
||||
@wraps(msg_callback)
|
||||
async def async_wrapper(msg: Any) -> None:
|
||||
async def async_wrapper(msg: Message) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
await msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
|
||||
msg.topic, msg.payload, msg.qos
|
||||
)
|
||||
|
||||
wrapper_func = async_wrapper
|
||||
else:
|
||||
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: Any) -> None:
|
||||
def wrapper(msg: Message) -> None:
|
||||
"""Call with deprecated signature."""
|
||||
msg_callback(msg.topic, msg.payload, msg.qos)
|
||||
|
||||
@ -315,7 +330,10 @@ def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:
|
||||
async def async_subscribe(
|
||||
hass: HomeAssistant,
|
||||
topic: str,
|
||||
msg_callback: MessageCallbackType,
|
||||
msg_callback: AsyncMessageCallbackType
|
||||
| MessageCallbackType
|
||||
| DeprecatedMessageCallbackType
|
||||
| AsyncDeprecatedMessageCallbackType,
|
||||
qos: int = DEFAULT_QOS,
|
||||
encoding: str | None = "utf-8",
|
||||
):
|
||||
@ -334,12 +352,15 @@ async def async_subscribe(
|
||||
wrapped_msg_callback = msg_callback
|
||||
# If we have 3 parameters with no default value, wrap the callback
|
||||
if non_default == 3:
|
||||
module = inspect.getmodule(msg_callback)
|
||||
_LOGGER.warning(
|
||||
"Signature of MQTT msg_callback '%s.%s' is deprecated",
|
||||
inspect.getmodule(msg_callback).__name__,
|
||||
module.__name__ if module else "<unknown>",
|
||||
msg_callback.__name__,
|
||||
)
|
||||
wrapped_msg_callback = wrap_msg_callback(msg_callback)
|
||||
wrapped_msg_callback = wrap_msg_callback(
|
||||
cast(DeprecatedMessageCallbackType, msg_callback)
|
||||
)
|
||||
|
||||
async_remove = await hass.data[DATA_MQTT].async_subscribe(
|
||||
topic,
|
||||
@ -378,16 +399,12 @@ def subscribe(
|
||||
|
||||
async def _async_setup_discovery(
|
||||
hass: HomeAssistant, conf: ConfigType, config_entry
|
||||
) -> bool:
|
||||
) -> None:
|
||||
"""Try to start the discovery of MQTT devices.
|
||||
|
||||
This method is a coroutine.
|
||||
"""
|
||||
success: bool = await discovery.async_start(
|
||||
hass, conf[CONF_DISCOVERY_PREFIX], config_entry
|
||||
)
|
||||
|
||||
return success
|
||||
await discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config_entry)
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
@ -539,7 +556,7 @@ class Subscription:
|
||||
matcher: Any = attr.ib()
|
||||
job: HassJob = attr.ib()
|
||||
qos: int = attr.ib(default=0)
|
||||
encoding: str = attr.ib(default="utf-8")
|
||||
encoding: str | None = attr.ib(default="utf-8")
|
||||
|
||||
|
||||
class MQTT:
|
||||
@ -566,7 +583,7 @@ class MQTT:
|
||||
self._mqttc: mqtt.Client = None
|
||||
self._paho_lock = asyncio.Lock()
|
||||
|
||||
self._pending_operations = {}
|
||||
self._pending_operations: dict[str, asyncio.Event] = {}
|
||||
|
||||
if self.hass.state == CoreState.running:
|
||||
self._ha_started.set()
|
||||
@ -688,12 +705,12 @@ class MQTT:
|
||||
_raise_on_error(msg_info.rc)
|
||||
await self._wait_for_mid(msg_info.mid)
|
||||
|
||||
async def async_connect(self) -> str:
|
||||
async def async_connect(self) -> None:
|
||||
"""Connect to the host. Does not process messages yet."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
result: int = None
|
||||
result: int | None = None
|
||||
try:
|
||||
result = await self.hass.async_add_executor_job(
|
||||
self._mqttc.connect,
|
||||
@ -770,7 +787,7 @@ class MQTT:
|
||||
This method is a coroutine.
|
||||
"""
|
||||
async with self._paho_lock:
|
||||
result: int = None
|
||||
result: int | None = None
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.unsubscribe, topic
|
||||
)
|
||||
@ -781,7 +798,7 @@ class MQTT:
|
||||
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
|
||||
"""Perform a paho-mqtt subscription."""
|
||||
async with self._paho_lock:
|
||||
result: int = None
|
||||
result: int | None = None
|
||||
result, mid = await self.hass.async_add_executor_job(
|
||||
self._mqttc.subscribe, topic, qos
|
||||
)
|
||||
@ -952,12 +969,12 @@ class MQTT:
|
||||
)
|
||||
|
||||
|
||||
def _raise_on_error(result_code: int) -> None:
|
||||
def _raise_on_error(result_code: int | None) -> None:
|
||||
"""Raise error if error result."""
|
||||
# pylint: disable=import-outside-toplevel
|
||||
import paho.mqtt.client as mqtt
|
||||
|
||||
if result_code != 0:
|
||||
if result_code is not None and result_code != 0:
|
||||
raise HomeAssistantError(
|
||||
f"Error talking to MQTT: {mqtt.error_string(result_code)}"
|
||||
)
|
||||
@ -1014,13 +1031,13 @@ async def websocket_remove_device(hass, connection, msg):
|
||||
)
|
||||
|
||||
|
||||
@websocket_api.async_response
|
||||
@websocket_api.websocket_command(
|
||||
{
|
||||
vol.Required("type"): "mqtt/subscribe",
|
||||
vol.Required("topic"): valid_subscribe_topic,
|
||||
}
|
||||
)
|
||||
@websocket_api.async_response
|
||||
async def websocket_subscribe(hass, connection, msg):
|
||||
"""Subscribe to a MQTT topic."""
|
||||
if not connection.user.is_admin:
|
||||
|
@ -226,6 +226,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity):
|
||||
def available(self) -> bool:
|
||||
"""Return true if the device is available and value has not expired."""
|
||||
expire_after = self._config.get(CONF_EXPIRE_AFTER)
|
||||
return MqttAvailability.available.fget(self) and (
|
||||
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
|
||||
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
|
||||
expire_after is None or not self._expired
|
||||
)
|
||||
|
@ -1,7 +1,7 @@
|
||||
"""Helper to handle a set of topics to subscribe to."""
|
||||
from collections import deque
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import Any, Callable
|
||||
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
@ -12,7 +12,9 @@ DATA_MQTT_DEBUG_INFO = "mqtt_debug_info"
|
||||
STORED_MESSAGES = 10
|
||||
|
||||
|
||||
def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType:
|
||||
def log_messages(
|
||||
hass: HomeAssistant, entity_id: str
|
||||
) -> Callable[[MessageCallbackType], MessageCallbackType]:
|
||||
"""Wrap an MQTT message callback to support message logging."""
|
||||
|
||||
def _log_message(msg):
|
||||
@ -24,7 +26,7 @@ def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType:
|
||||
if msg not in messages:
|
||||
messages.append(msg)
|
||||
|
||||
def _decorator(msg_callback: MessageCallbackType):
|
||||
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
|
||||
@wraps(msg_callback)
|
||||
def wrapper(msg: Any) -> None:
|
||||
"""Log message."""
|
||||
|
@ -119,15 +119,15 @@ class Trigger:
|
||||
"""Device trigger settings."""
|
||||
|
||||
device_id: str = attr.ib()
|
||||
discovery_data: dict = attr.ib()
|
||||
discovery_data: dict | None = attr.ib()
|
||||
hass: HomeAssistant = attr.ib()
|
||||
payload: str = attr.ib()
|
||||
qos: int = attr.ib()
|
||||
remove_signal: Callable[[], None] = attr.ib()
|
||||
payload: str | None = attr.ib()
|
||||
qos: int | None = attr.ib()
|
||||
remove_signal: Callable[[], None] | None = attr.ib()
|
||||
subtype: str = attr.ib()
|
||||
topic: str = attr.ib()
|
||||
topic: str | None = attr.ib()
|
||||
type: str = attr.ib()
|
||||
value_template: str = attr.ib()
|
||||
value_template: str | None = attr.ib()
|
||||
trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
|
||||
|
||||
async def add_trigger(self, action, automation_info):
|
||||
@ -289,7 +289,7 @@ async def async_device_removed(hass: HomeAssistant, device_id: str):
|
||||
|
||||
async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]:
|
||||
"""List device triggers for MQTT devices."""
|
||||
triggers = []
|
||||
triggers: list[dict] = []
|
||||
|
||||
if DEVICE_TRIGGERS not in hass.data:
|
||||
return triggers
|
||||
|
@ -83,7 +83,7 @@ class MQTTConfig(dict):
|
||||
|
||||
async def async_start( # noqa: C901
|
||||
hass: HomeAssistant, discovery_topic, config_entry=None
|
||||
) -> bool:
|
||||
) -> None:
|
||||
"""Start MQTT Discovery."""
|
||||
mqtt_integrations = {}
|
||||
|
||||
@ -298,10 +298,8 @@ async def async_start( # noqa: C901
|
||||
0,
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_stop(hass: HomeAssistant) -> bool:
|
||||
async def async_stop(hass: HomeAssistant) -> None:
|
||||
"""Stop MQTT Discovery."""
|
||||
if DISCOVERY_UNSUBSCRIBE in hass.data:
|
||||
for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]:
|
||||
|
@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
from abc import abstractmethod
|
||||
import json
|
||||
import logging
|
||||
from typing import Callable
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@ -194,11 +195,11 @@ async def async_setup_entry_helper(hass, domain, async_setup, schema):
|
||||
class MqttAttributes(Entity):
|
||||
"""Mixin used for platforms that support JSON attributes."""
|
||||
|
||||
_attributes_extra_blocked = frozenset()
|
||||
_attributes_extra_blocked: frozenset[str] = frozenset()
|
||||
|
||||
def __init__(self, config: dict) -> None:
|
||||
"""Initialize the JSON attributes mixin."""
|
||||
self._attributes = None
|
||||
self._attributes: dict | None = None
|
||||
self._attributes_sub_state = None
|
||||
self._attributes_config = config
|
||||
|
||||
@ -225,7 +226,7 @@ class MqttAttributes(Entity):
|
||||
payload = msg.payload
|
||||
if attr_tpl is not None:
|
||||
payload = attr_tpl.async_render_with_possible_json_value(payload)
|
||||
json_dict = json.loads(payload)
|
||||
json_dict = json.loads(payload) if isinstance(payload, str) else None
|
||||
if isinstance(json_dict, dict):
|
||||
filtered_dict = {
|
||||
k: v
|
||||
@ -272,7 +273,7 @@ class MqttAvailability(Entity):
|
||||
def __init__(self, config: dict) -> None:
|
||||
"""Initialize the availability mixin."""
|
||||
self._availability_sub_state = None
|
||||
self._available = {}
|
||||
self._available: dict = {}
|
||||
self._available_latest = False
|
||||
self._availability_setup_from_config(config)
|
||||
|
||||
@ -397,7 +398,7 @@ class MqttDiscoveryUpdate(Entity):
|
||||
"""Initialize the discovery update mixin."""
|
||||
self._discovery_data = discovery_data
|
||||
self._discovery_update = discovery_update
|
||||
self._remove_signal = None
|
||||
self._remove_signal: Callable | None = None
|
||||
self._removed_from_hass = False
|
||||
|
||||
async def async_added_to_hass(self) -> None:
|
||||
|
@ -2,7 +2,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime as dt
|
||||
from typing import Callable, Union
|
||||
from typing import Awaitable, Callable, Union
|
||||
|
||||
import attr
|
||||
|
||||
@ -21,4 +21,5 @@ class Message:
|
||||
timestamp: dt.datetime | None = attr.ib(default=None)
|
||||
|
||||
|
||||
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
|
||||
MessageCallbackType = Callable[[Message], None]
|
||||
|
@ -242,6 +242,7 @@ class MqttSensor(MqttEntity, SensorEntity):
|
||||
def available(self) -> bool:
|
||||
"""Return true if the device is available and value has not expired."""
|
||||
expire_after = self._config.get(CONF_EXPIRE_AFTER)
|
||||
return MqttAvailability.available.fget(self) and (
|
||||
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
|
||||
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
|
||||
expire_after is None or not self._expired
|
||||
)
|
||||
|
3
mypy.ini
3
mypy.ini
@ -1345,9 +1345,6 @@ ignore_errors = true
|
||||
[mypy-homeassistant.components.motion_blinds.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.mqtt.*]
|
||||
ignore_errors = true
|
||||
|
||||
[mypy-homeassistant.components.mullvad.*]
|
||||
ignore_errors = true
|
||||
|
||||
|
@ -125,7 +125,6 @@ IGNORED_MODULES: Final[list[str]] = [
|
||||
"homeassistant.components.minecraft_server.*",
|
||||
"homeassistant.components.mobile_app.*",
|
||||
"homeassistant.components.motion_blinds.*",
|
||||
"homeassistant.components.mqtt.*",
|
||||
"homeassistant.components.mullvad.*",
|
||||
"homeassistant.components.neato.*",
|
||||
"homeassistant.components.ness_alarm.*",
|
||||
|
Loading…
x
Reference in New Issue
Block a user