MQTT Static Typing (#12433)

* MQTT Typing

* Tiny style change

* Fixes

I should've probably really sticked to limiting myself to static typing...

* Small fix 😩

Ok, this seriously shouldn't have happened.
This commit is contained in:
Otto Winter 2018-02-28 22:59:14 +01:00 committed by GitHub
parent f7e9215f5e
commit c1aaef28a9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 127 additions and 81 deletions

View File

@ -5,9 +5,8 @@ For more details about this component, please refer to the documentation at
https://home-assistant.io/components/mqtt/
"""
import asyncio
from collections import namedtuple
from itertools import groupby
from typing import Optional
from typing import Optional, Any, Union, Callable, List, cast # noqa: F401
from operator import attrgetter
import logging
import os
@ -16,15 +15,17 @@ import time
import ssl
import re
import requests.certs
import attr
import voluptuous as vol
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.core import callback
from homeassistant.helpers.typing import HomeAssistantType, ConfigType, \
ServiceDataType
from homeassistant.core import callback, Event, ServiceCall
from homeassistant.setup import async_prepare_setup_platform
from homeassistant.exceptions import HomeAssistantError
from homeassistant.loader import bind_hass
from homeassistant.helpers import template, ConfigType, config_validation as cv
from homeassistant.helpers import template, config_validation as cv
from homeassistant.helpers.entity import Entity
from homeassistant.util.async import (
run_coroutine_threadsafe, run_callback_threadsafe)
@ -89,7 +90,7 @@ ATTR_RETAIN = CONF_RETAIN
MAX_RECONNECT_WAIT = 300 # seconds
def valid_subscribe_topic(value, invalid_chars='\0'):
def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str:
"""Validate that we can subscribe using this MQTT topic."""
value = cv.string(value)
if all(c not in value for c in invalid_chars):
@ -97,12 +98,12 @@ def valid_subscribe_topic(value, invalid_chars='\0'):
raise vol.Invalid('Invalid MQTT topic name')
def valid_publish_topic(value):
def valid_publish_topic(value: Any) -> str:
"""Validate that we can publish using this MQTT topic."""
return valid_subscribe_topic(value, invalid_chars='#+\0')
def valid_discovery_topic(value):
def valid_discovery_topic(value: Any) -> str:
"""Validate a discovery topic."""
return valid_subscribe_topic(value, invalid_chars='#+\0/')
@ -185,7 +186,13 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({
}, required=True)
def _build_publish_data(topic, qos, retain):
# pylint: disable=invalid-name
PublishPayloadType = Union[str, bytes, int, float, None]
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
MessageCallbackType = Callable[[str, SubscribePayloadType, int], None]
def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType:
"""Build the arguments for the publish service without the payload."""
data = {ATTR_TOPIC: topic}
if qos is not None:
@ -196,14 +203,16 @@ def _build_publish_data(topic, qos, retain):
@bind_hass
def publish(hass, topic, payload, qos=None, retain=None):
def publish(hass: HomeAssistantType, topic, payload, qos=None,
retain=None) -> None:
"""Publish message to an MQTT topic."""
hass.add_job(async_publish, hass, topic, payload, qos, retain)
@callback
@bind_hass
def async_publish(hass, topic, payload, qos=None, retain=None):
def async_publish(hass: HomeAssistantType, topic: Any, payload, qos=None,
retain=None) -> None:
"""Publish message to an MQTT topic."""
data = _build_publish_data(topic, qos, retain)
data[ATTR_PAYLOAD] = payload
@ -211,7 +220,8 @@ def async_publish(hass, topic, payload, qos=None, retain=None):
@bind_hass
def publish_template(hass, topic, payload_template, qos=None, retain=None):
def publish_template(hass: HomeAssistantType, topic, payload_template,
qos=None, retain=None) -> None:
"""Publish message to an MQTT topic using a template payload."""
data = _build_publish_data(topic, qos, retain)
data[ATTR_PAYLOAD_TEMPLATE] = payload_template
@ -220,21 +230,23 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
@asyncio.coroutine
@bind_hass
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
encoding='utf-8'):
def async_subscribe(hass: HomeAssistantType, topic: str,
msg_callback: MessageCallbackType,
qos: int = DEFAULT_QOS,
encoding: str = 'utf-8'):
"""Subscribe to an MQTT topic.
Call the return value to unsubscribe.
"""
async_remove = \
yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback,
qos, encoding)
async_remove = yield from hass.data[DATA_MQTT].async_subscribe(
topic, msg_callback, qos, encoding)
return async_remove
@bind_hass
def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
encoding='utf-8'):
def subscribe(hass: HomeAssistantType, topic: str,
msg_callback: MessageCallbackType, qos: int = DEFAULT_QOS,
encoding: str = 'utf-8') -> Callable[[], None]:
"""Subscribe to an MQTT topic."""
async_remove = run_coroutine_threadsafe(
async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
@ -248,12 +260,13 @@ def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
@asyncio.coroutine
def _async_setup_server(hass, config):
def _async_setup_server(hass: HomeAssistantType,
config: ConfigType):
"""Try to start embedded MQTT broker.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {})
conf = config.get(DOMAIN, {}) # type: ConfigType
server = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'server')
@ -265,26 +278,29 @@ def _async_setup_server(hass, config):
success, broker_config = \
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
return success and broker_config
if not success:
return None
return broker_config
@asyncio.coroutine
def _async_setup_discovery(hass, config):
def _async_setup_discovery(hass: HomeAssistantType,
config: ConfigType):
"""Try to start the discovery of MQTT devices.
This method is a coroutine.
"""
conf = config.get(DOMAIN, {})
conf = config.get(DOMAIN, {}) # type: ConfigType
discovery = yield from async_prepare_setup_platform(
hass, config, DOMAIN, 'discovery')
if discovery is None:
_LOGGER.error("Unable to load MQTT discovery")
return None
return False
success = yield from discovery.async_start(
hass, conf[CONF_DISCOVERY_PREFIX], config)
hass, conf[CONF_DISCOVERY_PREFIX], config) # type: bool
return success
@ -292,13 +308,14 @@ def _async_setup_discovery(hass, config):
@asyncio.coroutine
def async_setup(hass: HomeAssistantType, config: ConfigType):
"""Start the MQTT protocol service."""
conf = config.get(DOMAIN)
conf = config.get(DOMAIN) # type: Optional[ConfigType]
if conf is None:
conf = CONFIG_SCHEMA({DOMAIN: {}})[DOMAIN]
conf = cast(ConfigType, conf)
client_id = conf.get(CONF_CLIENT_ID)
keepalive = conf.get(CONF_KEEPALIVE)
client_id = conf.get(CONF_CLIENT_ID) # type: Optional[str]
keepalive = conf.get(CONF_KEEPALIVE) # type: int
# Only setup if embedded config passed in or no broker specified
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
@ -307,16 +324,16 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
broker_config = yield from _async_setup_server(hass, config)
if CONF_BROKER in conf:
broker = conf[CONF_BROKER]
port = conf[CONF_PORT]
username = conf.get(CONF_USERNAME)
password = conf.get(CONF_PASSWORD)
certificate = conf.get(CONF_CERTIFICATE)
client_key = conf.get(CONF_CLIENT_KEY)
client_cert = conf.get(CONF_CLIENT_CERT)
tls_insecure = conf.get(CONF_TLS_INSECURE)
protocol = conf[CONF_PROTOCOL]
elif broker_config:
broker = conf[CONF_BROKER] # type: str
port = conf[CONF_PORT] # type: int
username = conf.get(CONF_USERNAME) # type: Optional[str]
password = conf.get(CONF_PASSWORD) # type: Optional[str]
certificate = conf.get(CONF_CERTIFICATE) # type: Optional[str]
client_key = conf.get(CONF_CLIENT_KEY) # type: Optional[str]
client_cert = conf.get(CONF_CLIENT_CERT) # type: Optional[str]
tls_insecure = conf.get(CONF_TLS_INSECURE) # type: Optional[bool]
protocol = conf[CONF_PROTOCOL] # type: str
elif broker_config is not None:
# If no broker passed in, auto config to internal server
broker, port, username, password, certificate, protocol = broker_config
# Embedded broker doesn't have some ssl variables
@ -342,15 +359,15 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
if certificate == 'auto':
certificate = requests.certs.where()
will_message = None
will_message = None # type: Optional[Message]
if conf.get(CONF_WILL_MESSAGE) is not None:
will_message = Message(**conf.get(CONF_WILL_MESSAGE))
birth_message = None
birth_message = None # type: Optional[Message]
if conf.get(CONF_BIRTH_MESSAGE) is not None:
birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE))
# Be able to override versions other than TLSv1.0 under Python3.6
conf_tls_version = conf.get(CONF_TLS_VERSION)
conf_tls_version = conf.get(CONF_TLS_VERSION) # type: str
if conf_tls_version == '1.2':
tls_version = ssl.PROTOCOL_TLSv1_2
elif conf_tls_version == '1.1':
@ -376,24 +393,24 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
return False
@asyncio.coroutine
def async_stop_mqtt(event):
def async_stop_mqtt(event: Event):
"""Stop MQTT component."""
yield from hass.data[DATA_MQTT].async_disconnect()
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
success = yield from hass.data[DATA_MQTT].async_connect()
success = yield from hass.data[DATA_MQTT].async_connect() # type: bool
if not success:
return False
@asyncio.coroutine
def async_publish_service(call):
def async_publish_service(call: ServiceCall):
"""Handle MQTT publish service calls."""
msg_topic = call.data[ATTR_TOPIC]
msg_topic = call.data[ATTR_TOPIC] # type: str
payload = call.data.get(ATTR_PAYLOAD)
payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE)
qos = call.data[ATTR_QOS]
retain = call.data[ATTR_RETAIN]
qos = call.data[ATTR_QOS] # type: int
retain = call.data[ATTR_RETAIN] # type: bool
if payload_template is not None:
try:
payload = \
@ -418,21 +435,36 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
return True
Subscription = namedtuple('Subscription',
['topic', 'callback', 'qos', 'encoding'])
Subscription.__new__.__defaults__ = (0, 'utf-8')
@attr.s(slots=True, frozen=True)
class Subscription(object):
"""Class to hold data about an active subscription."""
Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain'])
Message.__new__.__defaults__ = (0, False)
topic = attr.ib(type=str)
callback = attr.ib(type=MessageCallbackType)
qos = attr.ib(type=int, default=0)
encoding = attr.ib(type=str, default='utf-8')
@attr.s(slots=True, frozen=True)
class Message(object):
"""MQTT Message."""
topic = attr.ib(type=str)
payload = attr.ib(type=PublishPayloadType)
qos = attr.ib(type=int, default=0)
retain = attr.ib(type=bool, default=False)
class MQTT(object):
"""Home Assistant MQTT client."""
def __init__(self, hass, broker, port, client_id, keepalive, username,
password, certificate, client_key, client_cert,
tls_insecure, protocol, will_message: Optional[Message],
birth_message: Optional[Message], tls_version):
def __init__(self, hass: HomeAssistantType, broker: str, port: int,
client_id: Optional[str], keepalive: Optional[int],
username: Optional[str], password: Optional[str],
certificate: Optional[str], client_key: Optional[str],
client_cert: Optional[str], tls_insecure: Optional[bool],
protocol: Optional[str], will_message: Optional[Message],
birth_message: Optional[Message], tls_version) -> None:
"""Initialize Home Assistant MQTT client."""
import paho.mqtt.client as mqtt
@ -440,13 +472,13 @@ class MQTT(object):
self.broker = broker
self.port = port
self.keepalive = keepalive
self.subscriptions = []
self.subscriptions = [] # type: List[Subscription]
self.birth_message = birth_message
self._mqttc = None
self._mqttc = None # type: mqtt.Client
self._paho_lock = asyncio.Lock(loop=hass.loop)
if protocol == PROTOCOL_31:
proto = mqtt.MQTTv31
proto = mqtt.MQTTv31 # type: int
else:
proto = mqtt.MQTTv311
@ -470,11 +502,12 @@ class MQTT(object):
self._mqttc.on_disconnect = self._mqtt_on_disconnect
self._mqttc.on_message = self._mqtt_on_message
if will_message:
self._mqttc.will_set(*will_message)
if will_message is not None:
self._mqttc.will_set(*attr.astuple(will_message))
@asyncio.coroutine
def async_publish(self, topic, payload, qos, retain):
def async_publish(self, topic: str, payload: PublishPayloadType, qos: int,
retain: bool):
"""Publish a MQTT message.
This method must be run in the event loop and returns a coroutine.
@ -489,6 +522,7 @@ class MQTT(object):
This method is a coroutine.
"""
result = None # type: int
result = yield from self.hass.async_add_job(
self._mqttc.connect, self.broker, self.port, self.keepalive)
@ -500,6 +534,7 @@ class MQTT(object):
return not result
@callback
def async_disconnect(self):
"""Stop the MQTT client.
@ -513,7 +548,8 @@ class MQTT(object):
return self.hass.async_add_job(stop)
@asyncio.coroutine
def async_subscribe(self, topic, msg_callback, qos, encoding):
def async_subscribe(self, topic: str, msg_callback: MessageCallbackType,
qos: int, encoding: str):
"""Set up a subscription to a topic with the provided qos.
This method is a coroutine.
@ -541,27 +577,31 @@ class MQTT(object):
return async_remove
@asyncio.coroutine
def _async_unsubscribe(self, topic):
def _async_unsubscribe(self, topic: str):
"""Unsubscribe from a topic.
This method is a coroutine.
"""
with (yield from self._paho_lock):
result = None # type: int
result, _ = yield from self.hass.async_add_job(
self._mqttc.unsubscribe, topic)
_raise_on_error(result)
@asyncio.coroutine
def _async_perform_subscription(self, topic, qos):
def _async_perform_subscription(self, topic: str,
qos: int):
"""Perform a paho-mqtt subscription."""
_LOGGER.debug("Subscribing to %s", topic)
with (yield from self._paho_lock):
result = None # type: int
result, _ = yield from self.hass.async_add_job(
self._mqttc.subscribe, topic, qos)
_raise_on_error(result)
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
def _mqtt_on_connect(self, _mqttc, _userdata, _flags,
result_code: int) -> None:
"""On connect callback.
Resubscribe to all topics we were subscribed to and publish birth
@ -584,21 +624,22 @@ class MQTT(object):
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
if self.birth_message:
self.hass.add_job(self.async_publish(*self.birth_message))
self.hass.add_job(
self.async_publish(*attr.astuple(self.birth_message)))
def _mqtt_on_message(self, _mqttc, _userdata, msg):
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
"""Message received callback."""
self.hass.add_job(self._mqtt_handle_message, msg)
@callback
def _mqtt_handle_message(self, msg):
def _mqtt_handle_message(self, msg) -> None:
_LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)
for subscription in self.subscriptions:
if not _match_topic(subscription.topic, msg.topic):
continue
payload = msg.payload
payload = msg.payload # type: SubscribePayloadType
if subscription.encoding is not None:
try:
payload = msg.payload.decode(subscription.encoding)
@ -612,7 +653,7 @@ class MQTT(object):
self.hass.async_run_job(subscription.callback,
msg.topic, payload, msg.qos)
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
"""Disconnected callback."""
# When disconnected because of calling disconnect()
if result_code == 0:
@ -637,18 +678,18 @@ class MQTT(object):
tries += 1
def _raise_on_error(result):
def _raise_on_error(result_code: int) -> None:
"""Raise error if error result."""
if result != 0:
if result_code != 0:
import paho.mqtt.client as mqtt
raise HomeAssistantError(
'Error talking to MQTT: {}'.format(mqtt.error_string(result)))
'Error talking to MQTT: {}'.format(mqtt.error_string(result_code)))
def _match_topic(subscription, topic):
def _match_topic(subscription: str, topic: str) -> bool:
"""Test if topic matches subscription."""
reg_ex_parts = []
reg_ex_parts = [] # type: List[str]
suffix = ""
if subscription.endswith('#'):
subscription = subscription[:-2]
@ -670,22 +711,26 @@ def _match_topic(subscription, topic):
class MqttAvailability(Entity):
"""Mixin used for platforms that report availability."""
def __init__(self, availability_topic, qos, payload_available,
payload_not_available):
def __init__(self, availability_topic: Optional[str], qos: Optional[int],
payload_available: Optional[str],
payload_not_available: Optional[str]) -> None:
"""Initialize the availability mixin."""
self._availability_topic = availability_topic
self._availability_qos = qos
self._available = availability_topic is None
self._available = availability_topic is None # type: bool
self._payload_available = payload_available
self._payload_not_available = payload_not_available
@asyncio.coroutine
def async_added_to_hass(self):
"""Subscribe mqtt events.
This method must be run in the event loop and returns a coroutine.
"""
@callback
def availability_message_received(topic, payload, qos):
def availability_message_received(topic: str,
payload: SubscribePayloadType,
qos: int) -> None:
"""Handle a new received MQTT availability message."""
if payload == self._payload_available:
self._available = True

View File

@ -8,6 +8,7 @@ import homeassistant.core
GPSType = Tuple[float, float]
ConfigType = Dict[str, Any]
HomeAssistantType = homeassistant.core.HomeAssistant
ServiceDataType = Dict[str, Any]
# Custom type for recorder Queries
QueryType = Any