mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 19:27:45 +00:00
MQTT Static Typing (#12433)
* MQTT Typing
* Tiny style change
* Fixes
I should've probably really sticked to limiting myself to static typing...
* Small fix 😩
Ok, this seriously shouldn't have happened.
This commit is contained in:
parent
f7e9215f5e
commit
c1aaef28a9
@ -5,9 +5,8 @@ For more details about this component, please refer to the documentation at
|
|||||||
https://home-assistant.io/components/mqtt/
|
https://home-assistant.io/components/mqtt/
|
||||||
"""
|
"""
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections import namedtuple
|
|
||||||
from itertools import groupby
|
from itertools import groupby
|
||||||
from typing import Optional
|
from typing import Optional, Any, Union, Callable, List, cast # noqa: F401
|
||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -16,15 +15,17 @@ import time
|
|||||||
import ssl
|
import ssl
|
||||||
import re
|
import re
|
||||||
import requests.certs
|
import requests.certs
|
||||||
|
import attr
|
||||||
|
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.helpers.typing import HomeAssistantType
|
from homeassistant.helpers.typing import HomeAssistantType, ConfigType, \
|
||||||
from homeassistant.core import callback
|
ServiceDataType
|
||||||
|
from homeassistant.core import callback, Event, ServiceCall
|
||||||
from homeassistant.setup import async_prepare_setup_platform
|
from homeassistant.setup import async_prepare_setup_platform
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.helpers import template, ConfigType, config_validation as cv
|
from homeassistant.helpers import template, config_validation as cv
|
||||||
from homeassistant.helpers.entity import Entity
|
from homeassistant.helpers.entity import Entity
|
||||||
from homeassistant.util.async import (
|
from homeassistant.util.async import (
|
||||||
run_coroutine_threadsafe, run_callback_threadsafe)
|
run_coroutine_threadsafe, run_callback_threadsafe)
|
||||||
@ -89,7 +90,7 @@ ATTR_RETAIN = CONF_RETAIN
|
|||||||
MAX_RECONNECT_WAIT = 300 # seconds
|
MAX_RECONNECT_WAIT = 300 # seconds
|
||||||
|
|
||||||
|
|
||||||
def valid_subscribe_topic(value, invalid_chars='\0'):
|
def valid_subscribe_topic(value: Any, invalid_chars='\0') -> str:
|
||||||
"""Validate that we can subscribe using this MQTT topic."""
|
"""Validate that we can subscribe using this MQTT topic."""
|
||||||
value = cv.string(value)
|
value = cv.string(value)
|
||||||
if all(c not in value for c in invalid_chars):
|
if all(c not in value for c in invalid_chars):
|
||||||
@ -97,12 +98,12 @@ def valid_subscribe_topic(value, invalid_chars='\0'):
|
|||||||
raise vol.Invalid('Invalid MQTT topic name')
|
raise vol.Invalid('Invalid MQTT topic name')
|
||||||
|
|
||||||
|
|
||||||
def valid_publish_topic(value):
|
def valid_publish_topic(value: Any) -> str:
|
||||||
"""Validate that we can publish using this MQTT topic."""
|
"""Validate that we can publish using this MQTT topic."""
|
||||||
return valid_subscribe_topic(value, invalid_chars='#+\0')
|
return valid_subscribe_topic(value, invalid_chars='#+\0')
|
||||||
|
|
||||||
|
|
||||||
def valid_discovery_topic(value):
|
def valid_discovery_topic(value: Any) -> str:
|
||||||
"""Validate a discovery topic."""
|
"""Validate a discovery topic."""
|
||||||
return valid_subscribe_topic(value, invalid_chars='#+\0/')
|
return valid_subscribe_topic(value, invalid_chars='#+\0/')
|
||||||
|
|
||||||
@ -185,7 +186,13 @@ MQTT_PUBLISH_SCHEMA = vol.Schema({
|
|||||||
}, required=True)
|
}, required=True)
|
||||||
|
|
||||||
|
|
||||||
def _build_publish_data(topic, qos, retain):
|
# pylint: disable=invalid-name
|
||||||
|
PublishPayloadType = Union[str, bytes, int, float, None]
|
||||||
|
SubscribePayloadType = Union[str, bytes] # Only bytes if encoding is None
|
||||||
|
MessageCallbackType = Callable[[str, SubscribePayloadType, int], None]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_publish_data(topic: Any, qos: int, retain: bool) -> ServiceDataType:
|
||||||
"""Build the arguments for the publish service without the payload."""
|
"""Build the arguments for the publish service without the payload."""
|
||||||
data = {ATTR_TOPIC: topic}
|
data = {ATTR_TOPIC: topic}
|
||||||
if qos is not None:
|
if qos is not None:
|
||||||
@ -196,14 +203,16 @@ def _build_publish_data(topic, qos, retain):
|
|||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def publish(hass, topic, payload, qos=None, retain=None):
|
def publish(hass: HomeAssistantType, topic, payload, qos=None,
|
||||||
|
retain=None) -> None:
|
||||||
"""Publish message to an MQTT topic."""
|
"""Publish message to an MQTT topic."""
|
||||||
hass.add_job(async_publish, hass, topic, payload, qos, retain)
|
hass.add_job(async_publish, hass, topic, payload, qos, retain)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_publish(hass, topic, payload, qos=None, retain=None):
|
def async_publish(hass: HomeAssistantType, topic: Any, payload, qos=None,
|
||||||
|
retain=None) -> None:
|
||||||
"""Publish message to an MQTT topic."""
|
"""Publish message to an MQTT topic."""
|
||||||
data = _build_publish_data(topic, qos, retain)
|
data = _build_publish_data(topic, qos, retain)
|
||||||
data[ATTR_PAYLOAD] = payload
|
data[ATTR_PAYLOAD] = payload
|
||||||
@ -211,7 +220,8 @@ def async_publish(hass, topic, payload, qos=None, retain=None):
|
|||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
def publish_template(hass: HomeAssistantType, topic, payload_template,
|
||||||
|
qos=None, retain=None) -> None:
|
||||||
"""Publish message to an MQTT topic using a template payload."""
|
"""Publish message to an MQTT topic using a template payload."""
|
||||||
data = _build_publish_data(topic, qos, retain)
|
data = _build_publish_data(topic, qos, retain)
|
||||||
data[ATTR_PAYLOAD_TEMPLATE] = payload_template
|
data[ATTR_PAYLOAD_TEMPLATE] = payload_template
|
||||||
@ -220,21 +230,23 @@ def publish_template(hass, topic, payload_template, qos=None, retain=None):
|
|||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def async_subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
|
def async_subscribe(hass: HomeAssistantType, topic: str,
|
||||||
encoding='utf-8'):
|
msg_callback: MessageCallbackType,
|
||||||
|
qos: int = DEFAULT_QOS,
|
||||||
|
encoding: str = 'utf-8'):
|
||||||
"""Subscribe to an MQTT topic.
|
"""Subscribe to an MQTT topic.
|
||||||
|
|
||||||
Call the return value to unsubscribe.
|
Call the return value to unsubscribe.
|
||||||
"""
|
"""
|
||||||
async_remove = \
|
async_remove = yield from hass.data[DATA_MQTT].async_subscribe(
|
||||||
yield from hass.data[DATA_MQTT].async_subscribe(topic, msg_callback,
|
topic, msg_callback, qos, encoding)
|
||||||
qos, encoding)
|
|
||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
|
|
||||||
@bind_hass
|
@bind_hass
|
||||||
def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
|
def subscribe(hass: HomeAssistantType, topic: str,
|
||||||
encoding='utf-8'):
|
msg_callback: MessageCallbackType, qos: int = DEFAULT_QOS,
|
||||||
|
encoding: str = 'utf-8') -> Callable[[], None]:
|
||||||
"""Subscribe to an MQTT topic."""
|
"""Subscribe to an MQTT topic."""
|
||||||
async_remove = run_coroutine_threadsafe(
|
async_remove = run_coroutine_threadsafe(
|
||||||
async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
|
async_subscribe(hass, topic, msg_callback, qos, encoding), hass.loop
|
||||||
@ -248,12 +260,13 @@ def subscribe(hass, topic, msg_callback, qos=DEFAULT_QOS,
|
|||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _async_setup_server(hass, config):
|
def _async_setup_server(hass: HomeAssistantType,
|
||||||
|
config: ConfigType):
|
||||||
"""Try to start embedded MQTT broker.
|
"""Try to start embedded MQTT broker.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
conf = config.get(DOMAIN, {})
|
conf = config.get(DOMAIN, {}) # type: ConfigType
|
||||||
|
|
||||||
server = yield from async_prepare_setup_platform(
|
server = yield from async_prepare_setup_platform(
|
||||||
hass, config, DOMAIN, 'server')
|
hass, config, DOMAIN, 'server')
|
||||||
@ -265,26 +278,29 @@ def _async_setup_server(hass, config):
|
|||||||
success, broker_config = \
|
success, broker_config = \
|
||||||
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
|
yield from server.async_start(hass, conf.get(CONF_EMBEDDED))
|
||||||
|
|
||||||
return success and broker_config
|
if not success:
|
||||||
|
return None
|
||||||
|
return broker_config
|
||||||
|
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _async_setup_discovery(hass, config):
|
def _async_setup_discovery(hass: HomeAssistantType,
|
||||||
|
config: ConfigType):
|
||||||
"""Try to start the discovery of MQTT devices.
|
"""Try to start the discovery of MQTT devices.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
conf = config.get(DOMAIN, {})
|
conf = config.get(DOMAIN, {}) # type: ConfigType
|
||||||
|
|
||||||
discovery = yield from async_prepare_setup_platform(
|
discovery = yield from async_prepare_setup_platform(
|
||||||
hass, config, DOMAIN, 'discovery')
|
hass, config, DOMAIN, 'discovery')
|
||||||
|
|
||||||
if discovery is None:
|
if discovery is None:
|
||||||
_LOGGER.error("Unable to load MQTT discovery")
|
_LOGGER.error("Unable to load MQTT discovery")
|
||||||
return None
|
return False
|
||||||
|
|
||||||
success = yield from discovery.async_start(
|
success = yield from discovery.async_start(
|
||||||
hass, conf[CONF_DISCOVERY_PREFIX], config)
|
hass, conf[CONF_DISCOVERY_PREFIX], config) # type: bool
|
||||||
|
|
||||||
return success
|
return success
|
||||||
|
|
||||||
@ -292,13 +308,14 @@ def _async_setup_discovery(hass, config):
|
|||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_setup(hass: HomeAssistantType, config: ConfigType):
|
def async_setup(hass: HomeAssistantType, config: ConfigType):
|
||||||
"""Start the MQTT protocol service."""
|
"""Start the MQTT protocol service."""
|
||||||
conf = config.get(DOMAIN)
|
conf = config.get(DOMAIN) # type: Optional[ConfigType]
|
||||||
|
|
||||||
if conf is None:
|
if conf is None:
|
||||||
conf = CONFIG_SCHEMA({DOMAIN: {}})[DOMAIN]
|
conf = CONFIG_SCHEMA({DOMAIN: {}})[DOMAIN]
|
||||||
|
conf = cast(ConfigType, conf)
|
||||||
|
|
||||||
client_id = conf.get(CONF_CLIENT_ID)
|
client_id = conf.get(CONF_CLIENT_ID) # type: Optional[str]
|
||||||
keepalive = conf.get(CONF_KEEPALIVE)
|
keepalive = conf.get(CONF_KEEPALIVE) # type: int
|
||||||
|
|
||||||
# Only setup if embedded config passed in or no broker specified
|
# Only setup if embedded config passed in or no broker specified
|
||||||
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
|
if CONF_EMBEDDED not in conf and CONF_BROKER in conf:
|
||||||
@ -307,16 +324,16 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
broker_config = yield from _async_setup_server(hass, config)
|
broker_config = yield from _async_setup_server(hass, config)
|
||||||
|
|
||||||
if CONF_BROKER in conf:
|
if CONF_BROKER in conf:
|
||||||
broker = conf[CONF_BROKER]
|
broker = conf[CONF_BROKER] # type: str
|
||||||
port = conf[CONF_PORT]
|
port = conf[CONF_PORT] # type: int
|
||||||
username = conf.get(CONF_USERNAME)
|
username = conf.get(CONF_USERNAME) # type: Optional[str]
|
||||||
password = conf.get(CONF_PASSWORD)
|
password = conf.get(CONF_PASSWORD) # type: Optional[str]
|
||||||
certificate = conf.get(CONF_CERTIFICATE)
|
certificate = conf.get(CONF_CERTIFICATE) # type: Optional[str]
|
||||||
client_key = conf.get(CONF_CLIENT_KEY)
|
client_key = conf.get(CONF_CLIENT_KEY) # type: Optional[str]
|
||||||
client_cert = conf.get(CONF_CLIENT_CERT)
|
client_cert = conf.get(CONF_CLIENT_CERT) # type: Optional[str]
|
||||||
tls_insecure = conf.get(CONF_TLS_INSECURE)
|
tls_insecure = conf.get(CONF_TLS_INSECURE) # type: Optional[bool]
|
||||||
protocol = conf[CONF_PROTOCOL]
|
protocol = conf[CONF_PROTOCOL] # type: str
|
||||||
elif broker_config:
|
elif broker_config is not None:
|
||||||
# If no broker passed in, auto config to internal server
|
# If no broker passed in, auto config to internal server
|
||||||
broker, port, username, password, certificate, protocol = broker_config
|
broker, port, username, password, certificate, protocol = broker_config
|
||||||
# Embedded broker doesn't have some ssl variables
|
# Embedded broker doesn't have some ssl variables
|
||||||
@ -342,15 +359,15 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
if certificate == 'auto':
|
if certificate == 'auto':
|
||||||
certificate = requests.certs.where()
|
certificate = requests.certs.where()
|
||||||
|
|
||||||
will_message = None
|
will_message = None # type: Optional[Message]
|
||||||
if conf.get(CONF_WILL_MESSAGE) is not None:
|
if conf.get(CONF_WILL_MESSAGE) is not None:
|
||||||
will_message = Message(**conf.get(CONF_WILL_MESSAGE))
|
will_message = Message(**conf.get(CONF_WILL_MESSAGE))
|
||||||
birth_message = None
|
birth_message = None # type: Optional[Message]
|
||||||
if conf.get(CONF_BIRTH_MESSAGE) is not None:
|
if conf.get(CONF_BIRTH_MESSAGE) is not None:
|
||||||
birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE))
|
birth_message = Message(**conf.get(CONF_BIRTH_MESSAGE))
|
||||||
|
|
||||||
# Be able to override versions other than TLSv1.0 under Python3.6
|
# Be able to override versions other than TLSv1.0 under Python3.6
|
||||||
conf_tls_version = conf.get(CONF_TLS_VERSION)
|
conf_tls_version = conf.get(CONF_TLS_VERSION) # type: str
|
||||||
if conf_tls_version == '1.2':
|
if conf_tls_version == '1.2':
|
||||||
tls_version = ssl.PROTOCOL_TLSv1_2
|
tls_version = ssl.PROTOCOL_TLSv1_2
|
||||||
elif conf_tls_version == '1.1':
|
elif conf_tls_version == '1.1':
|
||||||
@ -376,24 +393,24 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_stop_mqtt(event):
|
def async_stop_mqtt(event: Event):
|
||||||
"""Stop MQTT component."""
|
"""Stop MQTT component."""
|
||||||
yield from hass.data[DATA_MQTT].async_disconnect()
|
yield from hass.data[DATA_MQTT].async_disconnect()
|
||||||
|
|
||||||
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, async_stop_mqtt)
|
||||||
|
|
||||||
success = yield from hass.data[DATA_MQTT].async_connect()
|
success = yield from hass.data[DATA_MQTT].async_connect() # type: bool
|
||||||
if not success:
|
if not success:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_publish_service(call):
|
def async_publish_service(call: ServiceCall):
|
||||||
"""Handle MQTT publish service calls."""
|
"""Handle MQTT publish service calls."""
|
||||||
msg_topic = call.data[ATTR_TOPIC]
|
msg_topic = call.data[ATTR_TOPIC] # type: str
|
||||||
payload = call.data.get(ATTR_PAYLOAD)
|
payload = call.data.get(ATTR_PAYLOAD)
|
||||||
payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE)
|
payload_template = call.data.get(ATTR_PAYLOAD_TEMPLATE)
|
||||||
qos = call.data[ATTR_QOS]
|
qos = call.data[ATTR_QOS] # type: int
|
||||||
retain = call.data[ATTR_RETAIN]
|
retain = call.data[ATTR_RETAIN] # type: bool
|
||||||
if payload_template is not None:
|
if payload_template is not None:
|
||||||
try:
|
try:
|
||||||
payload = \
|
payload = \
|
||||||
@ -418,21 +435,36 @@ def async_setup(hass: HomeAssistantType, config: ConfigType):
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
Subscription = namedtuple('Subscription',
|
@attr.s(slots=True, frozen=True)
|
||||||
['topic', 'callback', 'qos', 'encoding'])
|
class Subscription(object):
|
||||||
Subscription.__new__.__defaults__ = (0, 'utf-8')
|
"""Class to hold data about an active subscription."""
|
||||||
|
|
||||||
Message = namedtuple('Message', ['topic', 'payload', 'qos', 'retain'])
|
topic = attr.ib(type=str)
|
||||||
Message.__new__.__defaults__ = (0, False)
|
callback = attr.ib(type=MessageCallbackType)
|
||||||
|
qos = attr.ib(type=int, default=0)
|
||||||
|
encoding = attr.ib(type=str, default='utf-8')
|
||||||
|
|
||||||
|
|
||||||
|
@attr.s(slots=True, frozen=True)
|
||||||
|
class Message(object):
|
||||||
|
"""MQTT Message."""
|
||||||
|
|
||||||
|
topic = attr.ib(type=str)
|
||||||
|
payload = attr.ib(type=PublishPayloadType)
|
||||||
|
qos = attr.ib(type=int, default=0)
|
||||||
|
retain = attr.ib(type=bool, default=False)
|
||||||
|
|
||||||
|
|
||||||
class MQTT(object):
|
class MQTT(object):
|
||||||
"""Home Assistant MQTT client."""
|
"""Home Assistant MQTT client."""
|
||||||
|
|
||||||
def __init__(self, hass, broker, port, client_id, keepalive, username,
|
def __init__(self, hass: HomeAssistantType, broker: str, port: int,
|
||||||
password, certificate, client_key, client_cert,
|
client_id: Optional[str], keepalive: Optional[int],
|
||||||
tls_insecure, protocol, will_message: Optional[Message],
|
username: Optional[str], password: Optional[str],
|
||||||
birth_message: Optional[Message], tls_version):
|
certificate: Optional[str], client_key: Optional[str],
|
||||||
|
client_cert: Optional[str], tls_insecure: Optional[bool],
|
||||||
|
protocol: Optional[str], will_message: Optional[Message],
|
||||||
|
birth_message: Optional[Message], tls_version) -> None:
|
||||||
"""Initialize Home Assistant MQTT client."""
|
"""Initialize Home Assistant MQTT client."""
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
@ -440,13 +472,13 @@ class MQTT(object):
|
|||||||
self.broker = broker
|
self.broker = broker
|
||||||
self.port = port
|
self.port = port
|
||||||
self.keepalive = keepalive
|
self.keepalive = keepalive
|
||||||
self.subscriptions = []
|
self.subscriptions = [] # type: List[Subscription]
|
||||||
self.birth_message = birth_message
|
self.birth_message = birth_message
|
||||||
self._mqttc = None
|
self._mqttc = None # type: mqtt.Client
|
||||||
self._paho_lock = asyncio.Lock(loop=hass.loop)
|
self._paho_lock = asyncio.Lock(loop=hass.loop)
|
||||||
|
|
||||||
if protocol == PROTOCOL_31:
|
if protocol == PROTOCOL_31:
|
||||||
proto = mqtt.MQTTv31
|
proto = mqtt.MQTTv31 # type: int
|
||||||
else:
|
else:
|
||||||
proto = mqtt.MQTTv311
|
proto = mqtt.MQTTv311
|
||||||
|
|
||||||
@ -470,11 +502,12 @@ class MQTT(object):
|
|||||||
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
self._mqttc.on_disconnect = self._mqtt_on_disconnect
|
||||||
self._mqttc.on_message = self._mqtt_on_message
|
self._mqttc.on_message = self._mqtt_on_message
|
||||||
|
|
||||||
if will_message:
|
if will_message is not None:
|
||||||
self._mqttc.will_set(*will_message)
|
self._mqttc.will_set(*attr.astuple(will_message))
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_publish(self, topic, payload, qos, retain):
|
def async_publish(self, topic: str, payload: PublishPayloadType, qos: int,
|
||||||
|
retain: bool):
|
||||||
"""Publish a MQTT message.
|
"""Publish a MQTT message.
|
||||||
|
|
||||||
This method must be run in the event loop and returns a coroutine.
|
This method must be run in the event loop and returns a coroutine.
|
||||||
@ -489,6 +522,7 @@ class MQTT(object):
|
|||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
|
result = None # type: int
|
||||||
result = yield from self.hass.async_add_job(
|
result = yield from self.hass.async_add_job(
|
||||||
self._mqttc.connect, self.broker, self.port, self.keepalive)
|
self._mqttc.connect, self.broker, self.port, self.keepalive)
|
||||||
|
|
||||||
@ -500,6 +534,7 @@ class MQTT(object):
|
|||||||
|
|
||||||
return not result
|
return not result
|
||||||
|
|
||||||
|
@callback
|
||||||
def async_disconnect(self):
|
def async_disconnect(self):
|
||||||
"""Stop the MQTT client.
|
"""Stop the MQTT client.
|
||||||
|
|
||||||
@ -513,7 +548,8 @@ class MQTT(object):
|
|||||||
return self.hass.async_add_job(stop)
|
return self.hass.async_add_job(stop)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def async_subscribe(self, topic, msg_callback, qos, encoding):
|
def async_subscribe(self, topic: str, msg_callback: MessageCallbackType,
|
||||||
|
qos: int, encoding: str):
|
||||||
"""Set up a subscription to a topic with the provided qos.
|
"""Set up a subscription to a topic with the provided qos.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
@ -541,27 +577,31 @@ class MQTT(object):
|
|||||||
return async_remove
|
return async_remove
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _async_unsubscribe(self, topic):
|
def _async_unsubscribe(self, topic: str):
|
||||||
"""Unsubscribe from a topic.
|
"""Unsubscribe from a topic.
|
||||||
|
|
||||||
This method is a coroutine.
|
This method is a coroutine.
|
||||||
"""
|
"""
|
||||||
with (yield from self._paho_lock):
|
with (yield from self._paho_lock):
|
||||||
|
result = None # type: int
|
||||||
result, _ = yield from self.hass.async_add_job(
|
result, _ = yield from self.hass.async_add_job(
|
||||||
self._mqttc.unsubscribe, topic)
|
self._mqttc.unsubscribe, topic)
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def _async_perform_subscription(self, topic, qos):
|
def _async_perform_subscription(self, topic: str,
|
||||||
|
qos: int):
|
||||||
"""Perform a paho-mqtt subscription."""
|
"""Perform a paho-mqtt subscription."""
|
||||||
_LOGGER.debug("Subscribing to %s", topic)
|
_LOGGER.debug("Subscribing to %s", topic)
|
||||||
|
|
||||||
with (yield from self._paho_lock):
|
with (yield from self._paho_lock):
|
||||||
|
result = None # type: int
|
||||||
result, _ = yield from self.hass.async_add_job(
|
result, _ = yield from self.hass.async_add_job(
|
||||||
self._mqttc.subscribe, topic, qos)
|
self._mqttc.subscribe, topic, qos)
|
||||||
_raise_on_error(result)
|
_raise_on_error(result)
|
||||||
|
|
||||||
def _mqtt_on_connect(self, _mqttc, _userdata, _flags, result_code):
|
def _mqtt_on_connect(self, _mqttc, _userdata, _flags,
|
||||||
|
result_code: int) -> None:
|
||||||
"""On connect callback.
|
"""On connect callback.
|
||||||
|
|
||||||
Resubscribe to all topics we were subscribed to and publish birth
|
Resubscribe to all topics we were subscribed to and publish birth
|
||||||
@ -584,21 +624,22 @@ class MQTT(object):
|
|||||||
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
self.hass.add_job(self._async_perform_subscription, topic, max_qos)
|
||||||
|
|
||||||
if self.birth_message:
|
if self.birth_message:
|
||||||
self.hass.add_job(self.async_publish(*self.birth_message))
|
self.hass.add_job(
|
||||||
|
self.async_publish(*attr.astuple(self.birth_message)))
|
||||||
|
|
||||||
def _mqtt_on_message(self, _mqttc, _userdata, msg):
|
def _mqtt_on_message(self, _mqttc, _userdata, msg) -> None:
|
||||||
"""Message received callback."""
|
"""Message received callback."""
|
||||||
self.hass.add_job(self._mqtt_handle_message, msg)
|
self.hass.add_job(self._mqtt_handle_message, msg)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _mqtt_handle_message(self, msg):
|
def _mqtt_handle_message(self, msg) -> None:
|
||||||
_LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)
|
_LOGGER.debug("Received message on %s: %s", msg.topic, msg.payload)
|
||||||
|
|
||||||
for subscription in self.subscriptions:
|
for subscription in self.subscriptions:
|
||||||
if not _match_topic(subscription.topic, msg.topic):
|
if not _match_topic(subscription.topic, msg.topic):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
payload = msg.payload
|
payload = msg.payload # type: SubscribePayloadType
|
||||||
if subscription.encoding is not None:
|
if subscription.encoding is not None:
|
||||||
try:
|
try:
|
||||||
payload = msg.payload.decode(subscription.encoding)
|
payload = msg.payload.decode(subscription.encoding)
|
||||||
@ -612,7 +653,7 @@ class MQTT(object):
|
|||||||
self.hass.async_run_job(subscription.callback,
|
self.hass.async_run_job(subscription.callback,
|
||||||
msg.topic, payload, msg.qos)
|
msg.topic, payload, msg.qos)
|
||||||
|
|
||||||
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code):
|
def _mqtt_on_disconnect(self, _mqttc, _userdata, result_code: int) -> None:
|
||||||
"""Disconnected callback."""
|
"""Disconnected callback."""
|
||||||
# When disconnected because of calling disconnect()
|
# When disconnected because of calling disconnect()
|
||||||
if result_code == 0:
|
if result_code == 0:
|
||||||
@ -637,18 +678,18 @@ class MQTT(object):
|
|||||||
tries += 1
|
tries += 1
|
||||||
|
|
||||||
|
|
||||||
def _raise_on_error(result):
|
def _raise_on_error(result_code: int) -> None:
|
||||||
"""Raise error if error result."""
|
"""Raise error if error result."""
|
||||||
if result != 0:
|
if result_code != 0:
|
||||||
import paho.mqtt.client as mqtt
|
import paho.mqtt.client as mqtt
|
||||||
|
|
||||||
raise HomeAssistantError(
|
raise HomeAssistantError(
|
||||||
'Error talking to MQTT: {}'.format(mqtt.error_string(result)))
|
'Error talking to MQTT: {}'.format(mqtt.error_string(result_code)))
|
||||||
|
|
||||||
|
|
||||||
def _match_topic(subscription, topic):
|
def _match_topic(subscription: str, topic: str) -> bool:
|
||||||
"""Test if topic matches subscription."""
|
"""Test if topic matches subscription."""
|
||||||
reg_ex_parts = []
|
reg_ex_parts = [] # type: List[str]
|
||||||
suffix = ""
|
suffix = ""
|
||||||
if subscription.endswith('#'):
|
if subscription.endswith('#'):
|
||||||
subscription = subscription[:-2]
|
subscription = subscription[:-2]
|
||||||
@ -670,22 +711,26 @@ def _match_topic(subscription, topic):
|
|||||||
class MqttAvailability(Entity):
|
class MqttAvailability(Entity):
|
||||||
"""Mixin used for platforms that report availability."""
|
"""Mixin used for platforms that report availability."""
|
||||||
|
|
||||||
def __init__(self, availability_topic, qos, payload_available,
|
def __init__(self, availability_topic: Optional[str], qos: Optional[int],
|
||||||
payload_not_available):
|
payload_available: Optional[str],
|
||||||
|
payload_not_available: Optional[str]) -> None:
|
||||||
"""Initialize the availability mixin."""
|
"""Initialize the availability mixin."""
|
||||||
self._availability_topic = availability_topic
|
self._availability_topic = availability_topic
|
||||||
self._availability_qos = qos
|
self._availability_qos = qos
|
||||||
self._available = availability_topic is None
|
self._available = availability_topic is None # type: bool
|
||||||
self._payload_available = payload_available
|
self._payload_available = payload_available
|
||||||
self._payload_not_available = payload_not_available
|
self._payload_not_available = payload_not_available
|
||||||
|
|
||||||
|
@asyncio.coroutine
|
||||||
def async_added_to_hass(self):
|
def async_added_to_hass(self):
|
||||||
"""Subscribe mqtt events.
|
"""Subscribe mqtt events.
|
||||||
|
|
||||||
This method must be run in the event loop and returns a coroutine.
|
This method must be run in the event loop and returns a coroutine.
|
||||||
"""
|
"""
|
||||||
@callback
|
@callback
|
||||||
def availability_message_received(topic, payload, qos):
|
def availability_message_received(topic: str,
|
||||||
|
payload: SubscribePayloadType,
|
||||||
|
qos: int) -> None:
|
||||||
"""Handle a new received MQTT availability message."""
|
"""Handle a new received MQTT availability message."""
|
||||||
if payload == self._payload_available:
|
if payload == self._payload_available:
|
||||||
self._available = True
|
self._available = True
|
||||||
|
@ -8,6 +8,7 @@ import homeassistant.core
|
|||||||
GPSType = Tuple[float, float]
|
GPSType = Tuple[float, float]
|
||||||
ConfigType = Dict[str, Any]
|
ConfigType = Dict[str, Any]
|
||||||
HomeAssistantType = homeassistant.core.HomeAssistant
|
HomeAssistantType = homeassistant.core.HomeAssistant
|
||||||
|
ServiceDataType = Dict[str, Any]
|
||||||
|
|
||||||
# Custom type for recorder Queries
|
# Custom type for recorder Queries
|
||||||
QueryType = Any
|
QueryType = Any
|
||||||
|
Loading…
x
Reference in New Issue
Block a user