From c1aaef28a9b84ad93d7face467a679861043d166 Mon Sep 17 00:00:00 2001 From: Otto Winter Date: Wed, 28 Feb 2018 22:59:14 +0100 Subject: [PATCH] MQTT Static Typing (#12433) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * MQTT Typing * Tiny style change * Fixes I should've probably really sticked to limiting myself to static typing... * Small fix 😩 Ok, this seriously shouldn't have happened. --- homeassistant/components/mqtt/__init__.py | 207 +++++++++++++--------- homeassistant/helpers/typing.py | 1 + 2 files changed, 127 insertions(+), 81 deletions(-) diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index 8a5fdb5b86b..63662d2072d 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -5,9 +5,8 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/mqtt/ """ import asyncio -from collections import namedtuple from itertools import groupby -from typing import Optional +from typing import Optional, Any, Union, Callable, List, cast # noqa: F401 from operator import attrgetter import logging import os @@ -16,15 +15,17 @@ import time import ssl import re import requests.certs +import attr import voluptuous as vol -from homeassistant.helpers.typing import HomeAssistantType -from homeassistant.core import callback +from homeassistant.helpers.typing import HomeAssistantType, ConfigType, \ + ServiceDataType +from homeassistant.core import callback, Event, ServiceCall from homeassistant.setup import async_prepare_setup_platform from homeassistant.exceptions import HomeAssistantError from homeassistant.loader import bind_hass -from homeassistant.helpers import template, ConfigType, config_validation as cv +from homeassistant.helpers import template, config_validation as cv from homeassistant.helpers.entity import Entity from homeassistant.util.async import ( run_coroutine_threadsafe, run_callback_threadsafe) @@ -89,7 +90,7 @@ ATTR_RETAIN = CONF_RETAIN MAX_RECONNECT_WAIT = 300 # seconds -def valid_subscribe_topic(value, invalid_chars='\0'): +def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str: """Validate that we can subscribe using this MQTT topic.""" value = cv.string(value) if all(c not in value for c in invalid_chars): @@ -97,12 +98,12 @@ def valid_subscribe_topic(value, invalid_chars='\0'): raise vol.Invalid('Invalid MQTT topic name') -def valid_publish_topic(value): +def valid_publish_topic(value: Any) -> str: """Validate that we can publish using this MQTT topic.""" return valid_subscribe_topic(value, invalid_chars='#+\0') -def valid_discovery_topic(value): +def valid_discovery_topic(value: Any) -> str: """Validate a discovery topic.""" return valid_subscribe_topic(value, invalid_chars='#+\0/') @@ -185,7 +186,13 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({ }, required=True) -def _build_publish_data(topic, qos, retain): +# pylint: disable=invalid-name +PublishPayloadType = Union[str, bytes, int, float, None] +SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None +MessageCallbackType = Callable[[str, SubscribePayloadType, int], None] + + +def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType: """Build the arguments for the publish service without the payload.""" data = {ATTR_TOPIC: topic} if qos is not None: @@ -196,14 +203,16 @@ def _build_publish_data(topic, qos, retain): @bind_hass -def publish(hass, topic, payload, qos=None, retain=None): +def publish(hass: HomeAssistantType, topic, payload, qos=None, + retain=None) -> None: """Publish message to an MQTT topic.""" hass.add_job(async_publish, hass, topic, payload, qos, retain) @callback @bind_hass -def async_publish(hass, topic, payload, qos=None, retain=None): +def async_publish(hass: HomeAssistantType, topic: Any, payload, qos=None, + retain=None) -> None: """Publish message to an MQTT topic.""" data = _build_publish_data(topic, qos, retain) data[ATTR_PAYLOAD] = payload @@ -211,7 +220,8 @@ def async_publish(hass, topic, payload, qos=None, retain=None): @bind_hass -def publish_template(hass, topic, payload_template, qos=None, retain=None): +def publish_template(hass: HomeAssistantType, topic, payload_template, + qos=None, retain=None) -> None: """Publish message to an MQTT topic using a template payload.""" data = _build_publish_data(topic, qos, retain) data[ATTR_PAYLOAD_TEMPLATE] = payload_template @@ -220,21 +230,23 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None): @asyncio.coroutine @bind_hass -def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, - encoding='utf-8'): +def async_subscribe(hass: HomeAssistantType, topic: str, + msg_callback: MessageCallbackType, + qos: int = DEFAULT_QOS, + encoding: str = 'utf-8'): """Subscribe to an MQTT topic. Call the return value to unsubscribe. """ - async_remove = \ - yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback, - qos, encoding) + async_remove = yield from hass.data[DATA_MQTT].async_subscribe( + topic, msg_callback, qos, encoding) return async_remove @bind_hass -def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, - encoding='utf-8'): +def subscribe(hass: HomeAssistantType, topic: str, + msg_callback: MessageCallbackType, qos: int = DEFAULT_QOS, + encoding: str = 'utf-8') -> Callable[[], None]: """Subscribe to an MQTT topic.""" async_remove = run_coroutine_threadsafe( async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop @@ -248,12 +260,13 @@ def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS, @asyncio.coroutine -def _async_setup_server(hass, config): +def _async_setup_server(hass: HomeAssistantType, + config: ConfigType): """Try to start embedded MQTT broker. This method is a coroutine. """ - conf = config.get(DOMAIN, {}) + conf = config.get(DOMAIN, {}) # type: ConfigType server = yield from async_prepare_setup_platform( hass, config, DOMAIN, 'server') @@ -265,26 +278,29 @@ def _async_setup_server(hass, config): success, broker_config = \ yield from server.async_start(hass, conf.get(CONF_EMBEDDED)) - return success and broker_config + if not success: + return None + return broker_config @asyncio.coroutine -def _async_setup_discovery(hass, config): +def _async_setup_discovery(hass: HomeAssistantType, + config: ConfigType): """Try to start the discovery of MQTT devices. This method is a coroutine. """ - conf = config.get(DOMAIN, {}) + conf = config.get(DOMAIN, {}) # type: ConfigType discovery = yield from async_prepare_setup_platform( hass, config, DOMAIN, 'discovery') if discovery is None: _LOGGER.error("Unable to load MQTT discovery") - return None + return False success = yield from discovery.async_start( - hass, conf[CONF_DISCOVERY_PREFIX], config) + hass, conf[CONF_DISCOVERY_PREFIX], config) # type: bool return success @@ -292,13 +308,14 @@ def _async_setup_discovery(hass, config): @asyncio.coroutine def async_setup(hass: HomeAssistantType, config: ConfigType): """Start the MQTT protocol service.""" - conf = config.get(DOMAIN) + conf = config.get(DOMAIN) # type: Optional[ConfigType] if conf is None: conf = CONFIG_SCHEMA({DOMAIN: {}})[DOMAIN] + conf = cast(ConfigType, conf) - client_id = conf.get(CONF_CLIENT_ID) - keepalive = conf.get(CONF_KEEPALIVE) + client_id = conf.get(CONF_CLIENT_ID) # type: Optional[str] + keepalive = conf.get(CONF_KEEPALIVE) # type: int # Only setup if embedded config passed in or no broker specified if CONF_EMBEDDED not in conf and CONF_BROKER in conf: @@ -307,16 +324,16 @@ def async_setup(hass: HomeAssistantType, config: ConfigType): broker_config = yield from _async_setup_server(hass, config) if CONF_BROKER in conf: - broker = conf[CONF_BROKER] - port = conf[CONF_PORT] - username = conf.get(CONF_USERNAME) - password = conf.get(CONF_PASSWORD) - certificate = conf.get(CONF_CERTIFICATE) - client_key = conf.get(CONF_CLIENT_KEY) - client_cert = conf.get(CONF_CLIENT_CERT) - tls_insecure = conf.get(CONF_TLS_INSECURE) - protocol = conf[CONF_PROTOCOL] - elif broker_config: + broker = conf[CONF_BROKER] # type: str + port = conf[CONF_PORT] # type: int + username = conf.get(CONF_USERNAME) # type: Optional[str] + password = conf.get(CONF_PASSWORD) # type: Optional[str] + certificate = conf.get(CONF_CERTIFICATE) # type: Optional[str] + client_key = conf.get(CONF_CLIENT_KEY) # type: Optional[str] + client_cert = conf.get(CONF_CLIENT_CERT) # type: Optional[str] + tls_insecure = conf.get(CONF_TLS_INSECURE) # type: Optional[bool] + protocol = conf[CONF_PROTOCOL] # type: str + elif broker_config is not None: # If no broker passed in, auto config to internal server broker, port, username, password, certificate, protocol = broker_config # Embedded broker doesn't have some ssl variables @@ -342,15 +359,15 @@ def async_setup(hass: HomeAssistantType, config: ConfigType): if certificate == 'auto': certificate = requests.certs.where() - will_message = None + will_message = None # type: Optional[Message] if conf.get(CONF_WILL_MESSAGE) is not None: will_message = Message(**conf.get(CONF_WILL_MESSAGE)) - birth_message = None + birth_message = None # type: Optional[Message] if conf.get(CONF_BIRTH_MESSAGE) is not None: birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE)) # Be able to override versions other than TLSv1.0 under Python3.6 - conf_tls_version = conf.get(CONF_TLS_VERSION) + conf_tls_version = conf.get(CONF_TLS_VERSION) # type: str if conf_tls_version == '1.2': tls_version = ssl.PROTOCOL_TLSv1_2 elif conf_tls_version == '1.1': @@ -376,24 +393,24 @@ def async_setup(hass: HomeAssistantType, config: ConfigType): return False @asyncio.coroutine - def async_stop_mqtt(event): + def async_stop_mqtt(event: Event): """Stop MQTT component.""" yield from hass.data[DATA_MQTT].async_disconnect() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt) - success = yield from hass.data[DATA_MQTT].async_connect() + success = yield from hass.data[DATA_MQTT].async_connect() # type: bool if not success: return False @asyncio.coroutine - def async_publish_service(call): + def async_publish_service(call: ServiceCall): """Handle MQTT publish service calls.""" - msg_topic = call.data[ATTR_TOPIC] + msg_topic = call.data[ATTR_TOPIC] # type: str payload = call.data.get(ATTR_PAYLOAD) payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE) - qos = call.data[ATTR_QOS] - retain = call.data[ATTR_RETAIN] + qos = call.data[ATTR_QOS] # type: int + retain = call.data[ATTR_RETAIN] # type: bool if payload_template is not None: try: payload = \ @@ -418,21 +435,36 @@ def async_setup(hass: HomeAssistantType, config: ConfigType): return True -Subscription = namedtuple('Subscription', - ['topic', 'callback', 'qos', 'encoding']) -Subscription.__new__.__defaults__ = (0, 'utf-8') +@attr.s(slots=True, frozen=True) +class Subscription(object): + """Class to hold data about an active subscription.""" -Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain']) -Message.__new__.__defaults__ = (0, False) + topic = attr.ib(type=str) + callback = attr.ib(type=MessageCallbackType) + qos = attr.ib(type=int, default=0) + encoding = attr.ib(type=str, default='utf-8') + + +@attr.s(slots=True, frozen=True) +class Message(object): + """MQTT Message.""" + + topic = attr.ib(type=str) + payload = attr.ib(type=PublishPayloadType) + qos = attr.ib(type=int, default=0) + retain = attr.ib(type=bool, default=False) class MQTT(object): """Home Assistant MQTT client.""" - def __init__(self, hass, broker, port, client_id, keepalive, username, - password, certificate, client_key, client_cert, - tls_insecure, protocol, will_message: Optional[Message], - birth_message: Optional[Message], tls_version): + def __init__(self, hass: HomeAssistantType, broker: str, port: int, + client_id: Optional[str], keepalive: Optional[int], + username: Optional[str], password: Optional[str], + certificate: Optional[str], client_key: Optional[str], + client_cert: Optional[str], tls_insecure: Optional[bool], + protocol: Optional[str], will_message: Optional[Message], + birth_message: Optional[Message], tls_version) -> None: """Initialize Home Assistant MQTT client.""" import paho.mqtt.client as mqtt @@ -440,13 +472,13 @@ class MQTT(object): self.broker = broker self.port = port self.keepalive = keepalive - self.subscriptions = [] + self.subscriptions = [] # type: List[Subscription] self.birth_message = birth_message - self._mqttc = None + self._mqttc = None # type: mqtt.Client self._paho_lock = asyncio.Lock(loop=hass.loop) if protocol == PROTOCOL_31: - proto = mqtt.MQTTv31 + proto = mqtt.MQTTv31 # type: int else: proto = mqtt.MQTTv311 @@ -470,11 +502,12 @@ class MQTT(object): self._mqttc.on_disconnect = self._mqtt_on_disconnect self._mqttc.on_message = self._mqtt_on_message - if will_message: - self._mqttc.will_set(*will_message) + if will_message is not None: + self._mqttc.will_set(*attr.astuple(will_message)) @asyncio.coroutine - def async_publish(self, topic, payload, qos, retain): + def async_publish(self, topic: str, payload: PublishPayloadType, qos: int, + retain: bool): """Publish a MQTT message. This method must be run in the event loop and returns a coroutine. @@ -489,6 +522,7 @@ class MQTT(object): This method is a coroutine. """ + result = None # type: int result = yield from self.hass.async_add_job( self._mqttc.connect, self.broker, self.port, self.keepalive) @@ -500,6 +534,7 @@ class MQTT(object): return not result + @callback def async_disconnect(self): """Stop the MQTT client. @@ -513,7 +548,8 @@ class MQTT(object): return self.hass.async_add_job(stop) @asyncio.coroutine - def async_subscribe(self, topic, msg_callback, qos, encoding): + def async_subscribe(self, topic: str, msg_callback: MessageCallbackType, + qos: int, encoding: str): """Set up a subscription to a topic with the provided qos. This method is a coroutine. @@ -541,27 +577,31 @@ class MQTT(object): return async_remove @asyncio.coroutine - def _async_unsubscribe(self, topic): + def _async_unsubscribe(self, topic: str): """Unsubscribe from a topic. This method is a coroutine. """ with (yield from self._paho_lock): + result = None # type: int result, _ = yield from self.hass.async_add_job( self._mqttc.unsubscribe, topic) _raise_on_error(result) @asyncio.coroutine - def _async_perform_subscription(self, topic, qos): + def _async_perform_subscription(self, topic: str, + qos: int): """Perform a paho-mqtt subscription.""" _LOGGER.debug("Subscribing to %s", topic) with (yield from self._paho_lock): + result = None # type: int result, _ = yield from self.hass.async_add_job( self._mqttc.subscribe, topic, qos) _raise_on_error(result) - def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code): + def _mqtt_on_connect(self, _mqttc, _userdata, _flags, + result_code: int) -> None: """On connect callback. Resubscribe to all topics we were subscribed to and publish birth @@ -584,21 +624,22 @@ class MQTT(object): self.hass.add_job(self._async_perform_subscription, topic, max_qos) if self.birth_message: - self.hass.add_job(self.async_publish(*self.birth_message)) + self.hass.add_job( + self.async_publish(*attr.astuple(self.birth_message))) - def _mqtt_on_message(self, _mqttc, _userdata, msg): + def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None: """Message received callback.""" self.hass.add_job(self._mqtt_handle_message, msg) @callback - def _mqtt_handle_message(self, msg): + def _mqtt_handle_message(self, msg) -> None: _LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload) for subscription in self.subscriptions: if not _match_topic(subscription.topic, msg.topic): continue - payload = msg.payload + payload = msg.payload # type: SubscribePayloadType if subscription.encoding is not None: try: payload = msg.payload.decode(subscription.encoding) @@ -612,7 +653,7 @@ class MQTT(object): self.hass.async_run_job(subscription.callback, msg.topic, payload, msg.qos) - def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code): + def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None: """Disconnected callback.""" # When disconnected because of calling disconnect() if result_code == 0: @@ -637,18 +678,18 @@ class MQTT(object): tries += 1 -def _raise_on_error(result): +def _raise_on_error(result_code: int) -> None: """Raise error if error result.""" - if result != 0: + if result_code != 0: import paho.mqtt.client as mqtt raise HomeAssistantError( - 'Error talking to MQTT: {}'.format(mqtt.error_string(result))) + 'Error talking to MQTT: {}'.format(mqtt.error_string(result_code))) -def _match_topic(subscription, topic): +def _match_topic(subscription: str, topic: str) -> bool: """Test if topic matches subscription.""" - reg_ex_parts = [] + reg_ex_parts = [] # type: List[str] suffix = "" if subscription.endswith('#'): subscription = subscription[:-2] @@ -670,22 +711,26 @@ def _match_topic(subscription, topic): class MqttAvailability(Entity): """Mixin used for platforms that report availability.""" - def __init__(self, availability_topic, qos, payload_available, - payload_not_available): + def __init__(self, availability_topic: Optional[str], qos: Optional[int], + payload_available: Optional[str], + payload_not_available: Optional[str]) -> None: """Initialize the availability mixin.""" self._availability_topic = availability_topic self._availability_qos = qos - self._available = availability_topic is None + self._available = availability_topic is None # type: bool self._payload_available = payload_available self._payload_not_available = payload_not_available + @asyncio.coroutine def async_added_to_hass(self): """Subscribe mqtt events. This method must be run in the event loop and returns a coroutine. """ @callback - def availability_message_received(topic, payload, qos): + def availability_message_received(topic: str, + payload: SubscribePayloadType, + qos: int) -> None: """Handle a new received MQTT availability message.""" if payload == self._payload_available: self._available = True diff --git a/homeassistant/helpers/typing.py b/homeassistant/helpers/typing.py index d0feab414da..3919d896fd1 100644 --- a/homeassistant/helpers/typing.py +++ b/homeassistant/helpers/typing.py @@ -8,6 +8,7 @@ import homeassistant.core GPSType = Tuple[float, float] ConfigType = Dict[str, Any] HomeAssistantType = homeassistant.core.HomeAssistant +ServiceDataType = Dict[str, Any] # Custom type for recorder Queries QueryType = Any