Enable basic type checking for mqtt (#52463)

* Enable basic type checking for mqtt

* Tweak
This commit is contained in:
Erik Montnemery 2021-07-05 10:33:12 +02:00 committed by GitHub
parent 5321151799
commit 79ee112490
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 68 additions and 51 deletions

View File

@ -9,7 +9,7 @@ import logging
from operator import attrgetter
import ssl
import time
from typing import Any, Callable, Union
from typing import Any, Awaitable, Callable, Union, cast
import uuid
import attr
@ -73,7 +73,12 @@ from .const import (
PROTOCOL_311,
)
from .discovery import LAST_DISCOVERY
from .models import Message, MessageCallbackType, PublishPayloadType
from .models import (
AsyncMessageCallbackType,
Message,
MessageCallbackType,
PublishPayloadType,
)
from .util import _VALID_QOS_SCHEMA, valid_publish_topic, valid_subscribe_topic
_LOGGER = logging.getLogger(__name__)
@ -284,26 +289,36 @@ def async_publish_template(
hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data))
def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:
AsyncDeprecatedMessageCallbackType = Callable[
[str, PublishPayloadType, int], Awaitable[None]
]
DeprecatedMessageCallbackType = Callable[[str, PublishPayloadType, int], None]
def wrap_msg_callback(
msg_callback: AsyncDeprecatedMessageCallbackType | DeprecatedMessageCallbackType,
) -> AsyncMessageCallbackType | MessageCallbackType:
"""Wrap an MQTT message callback to support deprecated signature."""
# Check for partials to properly determine if coroutine function
check_func = msg_callback
while isinstance(check_func, partial):
check_func = check_func.func
wrapper_func = None
wrapper_func: AsyncMessageCallbackType | MessageCallbackType
if asyncio.iscoroutinefunction(check_func):
@wraps(msg_callback)
async def async_wrapper(msg: Any) -> None:
async def async_wrapper(msg: Message) -> None:
"""Call with deprecated signature."""
await msg_callback(msg.topic, msg.payload, msg.qos)
await cast(AsyncDeprecatedMessageCallbackType, msg_callback)(
msg.topic, msg.payload, msg.qos
)
wrapper_func = async_wrapper
else:
@wraps(msg_callback)
def wrapper(msg: Any) -> None:
def wrapper(msg: Message) -> None:
"""Call with deprecated signature."""
msg_callback(msg.topic, msg.payload, msg.qos)
@ -315,7 +330,10 @@ def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:
async def async_subscribe(
hass: HomeAssistant,
topic: str,
msg_callback: MessageCallbackType,
msg_callback: AsyncMessageCallbackType
| MessageCallbackType
| DeprecatedMessageCallbackType
| AsyncDeprecatedMessageCallbackType,
qos: int = DEFAULT_QOS,
encoding: str | None = "utf-8",
):
@ -334,12 +352,15 @@ async def async_subscribe(
wrapped_msg_callback = msg_callback
# If we have 3 parameters with no default value, wrap the callback
if non_default == 3:
module = inspect.getmodule(msg_callback)
_LOGGER.warning(
"Signature of MQTT msg_callback '%s.%s' is deprecated",
inspect.getmodule(msg_callback).__name__,
module.__name__ if module else "<unknown>",
msg_callback.__name__,
)
wrapped_msg_callback = wrap_msg_callback(msg_callback)
wrapped_msg_callback = wrap_msg_callback(
cast(DeprecatedMessageCallbackType, msg_callback)
)
async_remove = await hass.data[DATA_MQTT].async_subscribe(
topic,
@ -378,16 +399,12 @@ def subscribe(
async def _async_setup_discovery(
hass: HomeAssistant, conf: ConfigType, config_entry
) -> bool:
) -> None:
"""Try to start the discovery of MQTT devices.
This method is a coroutine.
"""
success: bool = await discovery.async_start(
hass, conf[CONF_DISCOVERY_PREFIX], config_entry
)
return success
await discovery.async_start(hass, conf[CONF_DISCOVERY_PREFIX], config_entry)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
@ -539,7 +556,7 @@ class Subscription:
matcher: Any = attr.ib()
job: HassJob = attr.ib()
qos: int = attr.ib(default=0)
encoding: str = attr.ib(default="utf-8")
encoding: str | None = attr.ib(default="utf-8")
class MQTT:
@ -566,7 +583,7 @@ class MQTT:
self._mqttc: mqtt.Client = None
self._paho_lock = asyncio.Lock()
self._pending_operations = {}
self._pending_operations: dict[str, asyncio.Event] = {}
if self.hass.state == CoreState.running:
self._ha_started.set()
@ -688,12 +705,12 @@ class MQTT:
_raise_on_error(msg_info.rc)
await self._wait_for_mid(msg_info.mid)
async def async_connect(self) -> str:
async def async_connect(self) -> None:
"""Connect to the host. Does not process messages yet."""
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
result: int = None
result: int | None = None
try:
result = await self.hass.async_add_executor_job(
self._mqttc.connect,
@ -770,7 +787,7 @@ class MQTT:
This method is a coroutine.
"""
async with self._paho_lock:
result: int = None
result: int | None = None
result, mid = await self.hass.async_add_executor_job(
self._mqttc.unsubscribe, topic
)
@ -781,7 +798,7 @@ class MQTT:
async def _async_perform_subscription(self, topic: str, qos: int) -> None:
"""Perform a paho-mqtt subscription."""
async with self._paho_lock:
result: int = None
result: int | None = None
result, mid = await self.hass.async_add_executor_job(
self._mqttc.subscribe, topic, qos
)
@ -952,12 +969,12 @@ class MQTT:
)
def _raise_on_error(result_code: int) -> None:
def _raise_on_error(result_code: int | None) -> None:
"""Raise error if error result."""
# pylint: disable=import-outside-toplevel
import paho.mqtt.client as mqtt
if result_code != 0:
if result_code is not None and result_code != 0:
raise HomeAssistantError(
f"Error talking to MQTT: {mqtt.error_string(result_code)}"
)
@ -1014,13 +1031,13 @@ async def websocket_remove_device(hass, connection, msg):
)
@websocket_api.async_response
@websocket_api.websocket_command(
{
vol.Required("type"): "mqtt/subscribe",
vol.Required("topic"): valid_subscribe_topic,
}
)
@websocket_api.async_response
async def websocket_subscribe(hass, connection, msg):
"""Subscribe to a MQTT topic."""
if not connection.user.is_admin:

View File

@ -226,6 +226,7 @@ class MqttBinarySensor(MqttEntity, BinarySensorEntity):
def available(self) -> bool:
"""Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER)
return MqttAvailability.available.fget(self) and (
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired
)

View File

@ -1,7 +1,7 @@
"""Helper to handle a set of topics to subscribe to."""
from collections import deque
from functools import wraps
from typing import Any
from typing import Any, Callable
from homeassistant.core import HomeAssistant
@ -12,7 +12,9 @@ DATA_MQTT_DEBUG_INFO = "mqtt_debug_info"
STORED_MESSAGES = 10
def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType:
def log_messages(
hass: HomeAssistant, entity_id: str
) -> Callable[[MessageCallbackType], MessageCallbackType]:
"""Wrap an MQTT message callback to support message logging."""
def _log_message(msg):
@ -24,7 +26,7 @@ def log_messages(hass: HomeAssistant, entity_id: str) -> MessageCallbackType:
if msg not in messages:
messages.append(msg)
def _decorator(msg_callback: MessageCallbackType):
def _decorator(msg_callback: MessageCallbackType) -> MessageCallbackType:
@wraps(msg_callback)
def wrapper(msg: Any) -> None:
"""Log message."""

View File

@ -119,15 +119,15 @@ class Trigger:
"""Device trigger settings."""
device_id: str = attr.ib()
discovery_data: dict = attr.ib()
discovery_data: dict | None = attr.ib()
hass: HomeAssistant = attr.ib()
payload: str = attr.ib()
qos: int = attr.ib()
remove_signal: Callable[[], None] = attr.ib()
payload: str | None = attr.ib()
qos: int | None = attr.ib()
remove_signal: Callable[[], None] | None = attr.ib()
subtype: str = attr.ib()
topic: str = attr.ib()
topic: str | None = attr.ib()
type: str = attr.ib()
value_template: str = attr.ib()
value_template: str | None = attr.ib()
trigger_instances: list[TriggerInstance] = attr.ib(factory=list)
async def add_trigger(self, action, automation_info):
@ -289,7 +289,7 @@ async def async_device_removed(hass: HomeAssistant, device_id: str):
async def async_get_triggers(hass: HomeAssistant, device_id: str) -> list[dict]:
"""List device triggers for MQTT devices."""
triggers = []
triggers: list[dict] = []
if DEVICE_TRIGGERS not in hass.data:
return triggers

View File

@ -83,7 +83,7 @@ class MQTTConfig(dict):
async def async_start( # noqa: C901
hass: HomeAssistant, discovery_topic, config_entry=None
) -> bool:
) -> None:
"""Start MQTT Discovery."""
mqtt_integrations = {}
@ -298,10 +298,8 @@ async def async_start( # noqa: C901
0,
)
return True
async def async_stop(hass: HomeAssistant) -> bool:
async def async_stop(hass: HomeAssistant) -> None:
"""Stop MQTT Discovery."""
if DISCOVERY_UNSUBSCRIBE in hass.data:
for unsub in hass.data[DISCOVERY_UNSUBSCRIBE]:

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from abc import abstractmethod
import json
import logging
from typing import Callable
import voluptuous as vol
@ -194,11 +195,11 @@ async def async_setup_entry_helper(hass, domain, async_setup, schema):
class MqttAttributes(Entity):
"""Mixin used for platforms that support JSON attributes."""
_attributes_extra_blocked = frozenset()
_attributes_extra_blocked: frozenset[str] = frozenset()
def __init__(self, config: dict) -> None:
"""Initialize the JSON attributes mixin."""
self._attributes = None
self._attributes: dict | None = None
self._attributes_sub_state = None
self._attributes_config = config
@ -225,7 +226,7 @@ class MqttAttributes(Entity):
payload = msg.payload
if attr_tpl is not None:
payload = attr_tpl.async_render_with_possible_json_value(payload)
json_dict = json.loads(payload)
json_dict = json.loads(payload) if isinstance(payload, str) else None
if isinstance(json_dict, dict):
filtered_dict = {
k: v
@ -272,7 +273,7 @@ class MqttAvailability(Entity):
def __init__(self, config: dict) -> None:
"""Initialize the availability mixin."""
self._availability_sub_state = None
self._available = {}
self._available: dict = {}
self._available_latest = False
self._availability_setup_from_config(config)
@ -397,7 +398,7 @@ class MqttDiscoveryUpdate(Entity):
"""Initialize the discovery update mixin."""
self._discovery_data = discovery_data
self._discovery_update = discovery_update
self._remove_signal = None
self._remove_signal: Callable | None = None
self._removed_from_hass = False
async def async_added_to_hass(self) -> None:

View File

@ -2,7 +2,7 @@
from __future__ import annotations
import datetime as dt
from typing import Callable, Union
from typing import Awaitable, Callable, Union
import attr
@ -21,4 +21,5 @@ class Message:
timestamp: dt.datetime | None = attr.ib(default=None)
AsyncMessageCallbackType = Callable[[Message], Awaitable[None]]
MessageCallbackType = Callable[[Message], None]

View File

@ -242,6 +242,7 @@ class MqttSensor(MqttEntity, SensorEntity):
def available(self) -> bool:
"""Return true if the device is available and value has not expired."""
expire_after = self._config.get(CONF_EXPIRE_AFTER)
return MqttAvailability.available.fget(self) and (
# mypy doesn't know about fget: https://github.com/python/mypy/issues/6185
return MqttAvailability.available.fget(self) and ( # type: ignore[attr-defined]
expire_after is None or not self._expired
)

View File

@ -1345,9 +1345,6 @@ ignore_errors = true
[mypy-homeassistant.components.motion_blinds.*]
ignore_errors = true
[mypy-homeassistant.components.mqtt.*]
ignore_errors = true
[mypy-homeassistant.components.mullvad.*]
ignore_errors = true

View File

@ -125,7 +125,6 @@ IGNORED_MODULES: Final[list[str]] = [
"homeassistant.components.minecraft_server.*",
"homeassistant.components.mobile_app.*",
"homeassistant.components.motion_blinds.*",
"homeassistant.components.mqtt.*",
"homeassistant.components.mullvad.*",
"homeassistant.components.neato.*",
"homeassistant.components.ness_alarm.*",