Use pytest for more MQTT tests (#36859)

* Use pytest for more MQTT tests

* Address review comments

* Break out PAHO client mock in separate fixture.

* tweak.
This commit is contained in:
Erik Montnemery 2020-06-22 22:02:29 +02:00 committed by GitHub
parent 3a83f4bdbe
commit a2e2c35011
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 570 additions and 509 deletions

View File

@ -366,11 +366,19 @@ def async_publish(
@bind_hass @bind_hass
def publish_template( def publish_template(
hass: HomeAssistantType, topic, payload_template, qos=None, retain=None hass: HomeAssistantType, topic, payload_template, qos=None, retain=None
) -> None:
"""Publish message to an MQTT topic."""
hass.add_job(async_publish_template, hass, topic, payload_template, qos, retain)
@bind_hass
def async_publish_template(
hass: HomeAssistantType, topic, payload_template, qos=None, retain=None
) -> None: ) -> None:
"""Publish message to an MQTT topic using a template payload.""" """Publish message to an MQTT topic using a template payload."""
data = _build_publish_data(topic, qos, retain) data = _build_publish_data(topic, qos, retain)
data[ATTR_PAYLOAD_TEMPLATE] = payload_template data[ATTR_PAYLOAD_TEMPLATE] = payload_template
hass.services.call(DOMAIN, SERVICE_PUBLISH, data) hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data))
def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType: def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType:

View File

@ -344,10 +344,13 @@ async def async_mock_mqtt_component(hass, config=None):
assert result assert result
await hass.async_block_till_done() await hass.async_block_till_done()
hass.data["mqtt"] = MagicMock( mqtt_component_mock = MagicMock(
spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"] spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"]
) )
hass.data["mqtt"].connected = mqtt_component_mock.connected
mqtt_component_mock._mqttc = mock_client
hass.data["mqtt"] = mqtt_component_mock
return hass.data["mqtt"] return hass.data["mqtt"]

View File

@ -1,12 +1,52 @@
"""Test fixtures for mqtt component.""" """Test fixtures for mqtt component."""
import pytest import pytest
from tests.common import async_mock_mqtt_component from homeassistant import core as ha
from homeassistant.components import mqtt
from homeassistant.setup import async_setup_component
from tests.async_mock import MagicMock, patch
from tests.common import async_fire_mqtt_message
@pytest.fixture @pytest.fixture
def mqtt_mock(loop, hass): def mqtt_config():
"""Fixture to mock MQTT.""" """Fixture to allow overriding MQTT config."""
client = loop.run_until_complete(async_mock_mqtt_component(hass)) return None
client.reset_mock()
return client
@pytest.fixture
def mqtt_client_mock(hass):
"""Fixture to mock MQTT client."""
@ha.callback
def _async_fire_mqtt_message(topic, payload, qos, retain):
async_fire_mqtt_message(hass, topic, payload, qos, retain)
with patch("paho.mqtt.client.Client") as mock_client:
mock_client = mock_client.return_value
mock_client.connect.return_value = 0
mock_client.subscribe.return_value = (0, 0)
mock_client.unsubscribe.return_value = (0, 0)
mock_client.publish.side_effect = _async_fire_mqtt_message
yield mock_client
@pytest.fixture
async def mqtt_mock(hass, mqtt_client_mock, mqtt_config):
"""Fixture to mock MQTT component."""
if mqtt_config is None:
mqtt_config = {mqtt.CONF_BROKER: "mock-broker"}
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: mqtt_config})
assert result
await hass.async_block_till_done()
mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"])
hass.data["mqtt"].connected = mqtt_component_mock.connected
mqtt_component_mock._mqttc = mqtt_client_mock
hass.data["mqtt"] = mqtt_component_mock
component = hass.data["mqtt"]
component.reset_mock()
return component

View File

@ -2,7 +2,6 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
import json import json
import ssl import ssl
import unittest
import pytest import pytest
import voluptuous as vol import voluptuous as vol
@ -27,13 +26,8 @@ from tests.common import (
MockConfigEntry, MockConfigEntry,
async_fire_mqtt_message, async_fire_mqtt_message,
async_fire_time_changed, async_fire_time_changed,
async_mock_mqtt_component,
fire_mqtt_message,
get_test_home_assistant,
mock_device_registry, mock_device_registry,
mock_mqtt_component,
mock_registry, mock_registry,
threadsafe_coroutine_factory,
) )
from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES
@ -64,85 +58,74 @@ def mock_mqtt():
yield mock_mqtt yield mock_mqtt
async def async_mock_mqtt_client(hass, config=None): @pytest.fixture
"""Mock the MQTT paho client.""" def calls():
if config is None: """Fixture to record calls."""
config = {mqtt.CONF_BROKER: "mock-broker"} return []
with patch("paho.mqtt.client.Client") as mock_client:
mock_client().connect.return_value = 0
mock_client().subscribe.return_value = (0, 0)
mock_client().unsubscribe.return_value = (0, 0)
mock_client().publish.return_value = (0, 0)
result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config})
assert result
await hass.async_block_till_done()
return mock_client()
mock_mqtt_client = threadsafe_coroutine_factory(async_mock_mqtt_client) @pytest.fixture
def record_calls(calls):
"""Fixture to record calls."""
# pylint: disable=invalid-name
class TestMQTTComponent(unittest.TestCase):
"""Test the MQTT component."""
def setUp(self): # pylint: disable=invalid-name
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_mqtt_component(self.hass)
self.calls = []
self.addCleanup(self.tear_down_cleanup)
def tear_down_cleanup(self):
"""Stop everything that was started."""
self.hass.stop()
@callback @callback
def record_calls(self, *args): def record_calls(*args):
"""Record calls.""" """Record calls."""
self.calls.append(args) calls.append(args)
def aiohttp_client_stops_on_home_assistant_start(self): return record_calls
async def test_mqtt_connects_on_home_assistant_mqtt_setup(
hass, mqtt_client_mock, mqtt_mock
):
"""Test if client is connected after mqtt init on bootstrap."""
assert mqtt_client_mock.connect.call_count == 1
async def test_mqtt_disconnects_on_home_assistant_stop(hass, mqtt_mock):
"""Test if client stops on HA stop.""" """Test if client stops on HA stop."""
self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP) hass.bus.fire(EVENT_HOMEASSISTANT_STOP)
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"].async_disconnect.called await hass.async_block_till_done()
assert mqtt_mock.async_disconnect.called
def test_publish_calls_service(self):
async def test_publish_calls_service(hass, mqtt_mock, calls, record_calls):
"""Test the publishing of call to services.""" """Test the publishing of call to services."""
self.hass.bus.listen_once(EVENT_CALL_SERVICE, self.record_calls) hass.bus.async_listen_once(EVENT_CALL_SERVICE, record_calls)
mqtt.publish(self.hass, "test-topic", "test-payload") mqtt.async_publish(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].data["service_data"][mqtt.ATTR_TOPIC] == "test-topic" assert calls[0][0].data["service_data"][mqtt.ATTR_TOPIC] == "test-topic"
assert ( assert calls[0][0].data["service_data"][mqtt.ATTR_PAYLOAD] == "test-payload"
self.calls[0][0].data["service_data"][mqtt.ATTR_PAYLOAD] == "test-payload"
)
def test_service_call_without_topic_does_not_publish(self):
async def test_service_call_without_topic_does_not_publish(hass, mqtt_mock):
"""Test the service call if topic is missing.""" """Test the service call if topic is missing."""
self.hass.bus.fire( hass.bus.fire(
EVENT_CALL_SERVICE, EVENT_CALL_SERVICE,
{ATTR_DOMAIN: mqtt.DOMAIN, ATTR_SERVICE: mqtt.SERVICE_PUBLISH}, {ATTR_DOMAIN: mqtt.DOMAIN, ATTR_SERVICE: mqtt.SERVICE_PUBLISH},
) )
self.hass.block_till_done() await hass.async_block_till_done()
assert not self.hass.data["mqtt"].async_publish.called assert not mqtt_mock.async_publish.called
def test_service_call_with_template_payload_renders_template(self):
async def test_service_call_with_template_payload_renders_template(hass, mqtt_mock):
"""Test the service call with rendered template. """Test the service call with rendered template.
If 'payload_template' is provided and 'payload' is not, then render it. If 'payload_template' is provided and 'payload' is not, then render it.
""" """
mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}") mqtt.async_publish_template(hass, "test/topic", "{{ 1+1 }}")
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"].async_publish.called assert mqtt_mock.async_publish.called
assert self.hass.data["mqtt"].async_publish.call_args[0][1] == "2" assert mqtt_mock.async_publish.call_args[0][1] == "2"
def test_service_call_with_payload_doesnt_render_template(self):
async def test_service_call_with_payload_doesnt_render_template(hass, mqtt_mock):
"""Test the service call with unrendered template. """Test the service call with unrendered template.
If both 'payload' and 'payload_template' are provided then fail. If both 'payload' and 'payload_template' are provided then fail.
@ -150,7 +133,7 @@ class TestMQTTComponent(unittest.TestCase):
payload = "not a template" payload = "not a template"
payload_template = "a template" payload_template = "a template"
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
self.hass.services.call( await hass.services.async_call(
mqtt.DOMAIN, mqtt.DOMAIN,
mqtt.SERVICE_PUBLISH, mqtt.SERVICE_PUBLISH,
{ {
@ -160,14 +143,15 @@ class TestMQTTComponent(unittest.TestCase):
}, },
blocking=True, blocking=True,
) )
assert not self.hass.data["mqtt"].async_publish.called assert not mqtt_mock.async_publish.called
def test_service_call_with_ascii_qos_retain_flags(self):
async def test_service_call_with_ascii_qos_retain_flags(hass, mqtt_mock):
"""Test the service call with args that can be misinterpreted. """Test the service call with args that can be misinterpreted.
Empty payload message and ascii formatted qos and retain flags. Empty payload message and ascii formatted qos and retain flags.
""" """
self.hass.services.call( await hass.services.async_call(
mqtt.DOMAIN, mqtt.DOMAIN,
mqtt.SERVICE_PUBLISH, mqtt.SERVICE_PUBLISH,
{ {
@ -178,11 +162,12 @@ class TestMQTTComponent(unittest.TestCase):
}, },
blocking=True, blocking=True,
) )
assert self.hass.data["mqtt"].async_publish.called assert mqtt_mock.async_publish.called
assert self.hass.data["mqtt"].async_publish.call_args[0][2] == 2 assert mqtt_mock.async_publish.call_args[0][2] == 2
assert not self.hass.data["mqtt"].async_publish.call_args[0][3] assert not mqtt_mock.async_publish.call_args[0][3]
def test_validate_topic(self):
def test_validate_topic():
"""Test topic name/filter validation.""" """Test topic name/filter validation."""
# Invalid UTF-8, must not contain U+D800 to U+DFFF. # Invalid UTF-8, must not contain U+D800 to U+DFFF.
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
@ -207,7 +192,8 @@ class TestMQTTComponent(unittest.TestCase):
mqtt.valid_topic("\u009F") mqtt.valid_topic("\u009F")
mqtt.valid_topic("\uffff") mqtt.valid_topic("\uffff")
def test_validate_subscribe_topic(self):
def test_validate_subscribe_topic():
"""Test invalid subscribe topics.""" """Test invalid subscribe topics."""
mqtt.valid_subscribe_topic("#") mqtt.valid_subscribe_topic("#")
mqtt.valid_subscribe_topic("sport/#") mqtt.valid_subscribe_topic("sport/#")
@ -235,7 +221,8 @@ class TestMQTTComponent(unittest.TestCase):
mqtt.valid_subscribe_topic("+/+") mqtt.valid_subscribe_topic("+/+")
mqtt.valid_subscribe_topic("$SYS/#") mqtt.valid_subscribe_topic("$SYS/#")
def test_validate_publish_topic(self):
def test_validate_publish_topic():
"""Test invalid publish topics.""" """Test invalid publish topics."""
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
mqtt.valid_publish_topic("pub+") mqtt.valid_publish_topic("pub+")
@ -250,15 +237,14 @@ class TestMQTTComponent(unittest.TestCase):
# Topic names beginning with $ SHOULD NOT be used, but can # Topic names beginning with $ SHOULD NOT be used, but can
mqtt.valid_publish_topic("$SYS/") mqtt.valid_publish_topic("$SYS/")
def test_entity_device_info_schema(self):
def test_entity_device_info_schema():
"""Test MQTT entity device info validation.""" """Test MQTT entity device info validation."""
# just identifier # just identifier
mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": ["abcd"]}) mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": ["abcd"]})
mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": "abcd"}) mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": "abcd"})
# just connection # just connection
mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"connections": [["mac", "02:5b:26:a8:dc:12"]]})
{"connections": [["mac", "02:5b:26:a8:dc:12"]]}
)
# full device info # full device info
mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA(
{ {
@ -299,72 +285,53 @@ class TestMQTTComponent(unittest.TestCase):
) )
# pylint: disable=invalid-name async def test_receiving_non_utf8_message_gets_logged(
class TestMQTTCallbacks(unittest.TestCase): hass, mqtt_mock, calls, record_calls, caplog
"""Test the MQTT callbacks.""" ):
def setUp(self): # pylint: disable=invalid-name
"""Set up things to be run when tests are started."""
self.hass = get_test_home_assistant()
mock_mqtt_client(self.hass)
self.calls = []
self.addCleanup(self.tear_down_cleanup)
def tear_down_cleanup(self):
"""Stop everything that was started."""
self.hass.stop()
@callback
def record_calls(self, *args):
"""Record calls."""
self.calls.append(args)
def aiohttp_client_starts_on_home_assistant_mqtt_setup(self):
"""Test if client is connected after mqtt init on bootstrap."""
assert self.hass.data["mqtt"]._mqttc.connect.call_count == 1
def test_receiving_non_utf8_message_gets_logged(self):
"""Test receiving a non utf8 encoded message.""" """Test receiving a non utf8 encoded message."""
mqtt.subscribe(self.hass, "test-topic", self.record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
with self.assertLogs(level="WARNING") as test_handle: async_fire_mqtt_message(hass, "test-topic", b"\x9a")
fire_mqtt_message(self.hass, "test-topic", b"\x9a")
self.hass.block_till_done() await hass.async_block_till_done()
assert ( assert (
"WARNING:homeassistant.components.mqtt:Can't decode payload " "Can't decode payload b'\\x9a' on test-topic with encoding utf-8" in caplog.text
"b'\\x9a' on test-topic with encoding utf-8" in test_handle.output[0]
) )
def test_all_subscriptions_run_when_decode_fails(self):
async def test_all_subscriptions_run_when_decode_fails(
hass, mqtt_mock, calls, record_calls
):
"""Test all other subscriptions still run when decode fails for one.""" """Test all other subscriptions still run when decode fails for one."""
mqtt.subscribe(self.hass, "test-topic", self.record_calls, encoding="ascii") await mqtt.async_subscribe(hass, "test-topic", record_calls, encoding="ascii")
mqtt.subscribe(self.hass, "test-topic", self.record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
fire_mqtt_message(self.hass, "test-topic", TEMP_CELSIUS) async_fire_mqtt_message(hass, "test-topic", TEMP_CELSIUS)
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
def test_subscribe_topic(self):
async def test_subscribe_topic(hass, mqtt_mock, calls, record_calls):
"""Test the subscription of a topic.""" """Test the subscription of a topic."""
unsub = mqtt.subscribe(self.hass, "test-topic", self.record_calls) unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "test-topic" assert calls[0][0].topic == "test-topic"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
unsub() unsub()
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
def test_subscribe_deprecated(self):
async def test_subscribe_deprecated(hass, mqtt_mock):
"""Test the subscription of a topic using deprecated callback signature.""" """Test the subscription of a topic using deprecated callback signature."""
calls = [] calls = []
@ -373,23 +340,24 @@ class TestMQTTCallbacks(unittest.TestCase):
"""Record calls.""" """Record calls."""
calls.append((topic, payload, qos)) calls.append((topic, payload, qos))
unsub = mqtt.subscribe(self.hass, "test-topic", record_calls) unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0][0] == "test-topic" assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload" assert calls[0][1] == "test-payload"
unsub() unsub()
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
def test_subscribe_deprecated_async(self):
async def test_subscribe_deprecated_async(hass, mqtt_mock):
"""Test the subscription of a topic using deprecated callback signature.""" """Test the subscription of a topic using deprecated callback signature."""
calls = [] calls = []
@ -398,285 +366,334 @@ class TestMQTTCallbacks(unittest.TestCase):
"""Record calls.""" """Record calls."""
calls.append((topic, payload, qos)) calls.append((topic, payload, qos))
unsub = mqtt.subscribe(self.hass, "test-topic", record_calls) unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls)
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
assert calls[0][0] == "test-topic" assert calls[0][0] == "test-topic"
assert calls[0][1] == "test-payload" assert calls[0][1] == "test-payload"
unsub() unsub()
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(calls) == 1 assert len(calls) == 1
def test_subscribe_topic_not_match(self):
async def test_subscribe_topic_not_match(hass, mqtt_mock, calls, record_calls):
"""Test if subscribed topic is not a match.""" """Test if subscribed topic is not a match."""
mqtt.subscribe(self.hass, "test-topic", self.record_calls) await mqtt.async_subscribe(hass, "test-topic", record_calls)
fire_mqtt_message(self.hass, "another-test-topic", "test-payload") async_fire_mqtt_message(hass, "another-test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_level_wildcard(self):
async def test_subscribe_topic_level_wildcard(hass, mqtt_mock, calls, record_calls):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/+/on", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/+/on", record_calls)
fire_mqtt_message(self.hass, "test-topic/bier/on", "test-payload") async_fire_mqtt_message(hass, "test-topic/bier/on", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "test-topic/bier/on" assert calls[0][0].topic == "test-topic/bier/on"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_level_wildcard_no_subtree_match(self):
async def test_subscribe_topic_level_wildcard_no_subtree_match(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/+/on", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/+/on", record_calls)
fire_mqtt_message(self.hass, "test-topic/bier", "test-payload") async_fire_mqtt_message(hass, "test-topic/bier", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_level_wildcard_root_topic_no_subtree_match(self):
async def test_subscribe_topic_level_wildcard_root_topic_no_subtree_match(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/#", record_calls)
fire_mqtt_message(self.hass, "test-topic-123", "test-payload") async_fire_mqtt_message(hass, "test-topic-123", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_subtree_wildcard_subtree_topic(self):
async def test_subscribe_topic_subtree_wildcard_subtree_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/#", record_calls)
fire_mqtt_message(self.hass, "test-topic/bier/on", "test-payload") async_fire_mqtt_message(hass, "test-topic/bier/on", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "test-topic/bier/on" assert calls[0][0].topic == "test-topic/bier/on"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_subtree_wildcard_root_topic(self):
async def test_subscribe_topic_subtree_wildcard_root_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/#", record_calls)
fire_mqtt_message(self.hass, "test-topic", "test-payload") async_fire_mqtt_message(hass, "test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "test-topic" assert calls[0][0].topic == "test-topic"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_subtree_wildcard_no_match(self):
async def test_subscribe_topic_subtree_wildcard_no_match(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "test-topic/#", record_calls)
fire_mqtt_message(self.hass, "another-test-topic", "test-payload") async_fire_mqtt_message(hass, "another-test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self):
async def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls)
fire_mqtt_message(self.hass, "hi/test-topic", "test-payload") async_fire_mqtt_message(hass, "hi/test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "hi/test-topic" assert calls[0][0].topic == "hi/test-topic"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self):
async def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls)
fire_mqtt_message(self.hass, "hi/test-topic/here-iam", "test-payload") async_fire_mqtt_message(hass, "hi/test-topic/here-iam", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "hi/test-topic/here-iam" assert calls[0][0].topic == "hi/test-topic/here-iam"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self):
async def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls)
fire_mqtt_message(self.hass, "hi/here-iam/test-topic", "test-payload") async_fire_mqtt_message(hass, "hi/here-iam/test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self):
async def test_subscribe_topic_level_wildcard_and_wildcard_no_match(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of wildcard topics.""" """Test the subscription of wildcard topics."""
mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls)
fire_mqtt_message(self.hass, "hi/another-test-topic", "test-payload") async_fire_mqtt_message(hass, "hi/another-test-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 0 assert len(calls) == 0
def test_subscribe_topic_sys_root(self):
async def test_subscribe_topic_sys_root(hass, mqtt_mock, calls, record_calls):
"""Test the subscription of $ root topics.""" """Test the subscription of $ root topics."""
mqtt.subscribe(self.hass, "$test-topic/subtree/on", self.record_calls) await mqtt.async_subscribe(hass, "$test-topic/subtree/on", record_calls)
fire_mqtt_message(self.hass, "$test-topic/subtree/on", "test-payload") async_fire_mqtt_message(hass, "$test-topic/subtree/on", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "$test-topic/subtree/on" assert calls[0][0].topic == "$test-topic/subtree/on"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_sys_root_and_wildcard_topic(self):
async def test_subscribe_topic_sys_root_and_wildcard_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of $ root and wildcard topics.""" """Test the subscription of $ root and wildcard topics."""
mqtt.subscribe(self.hass, "$test-topic/#", self.record_calls) await mqtt.async_subscribe(hass, "$test-topic/#", record_calls)
fire_mqtt_message(self.hass, "$test-topic/some-topic", "test-payload") async_fire_mqtt_message(hass, "$test-topic/some-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "$test-topic/some-topic" assert calls[0][0].topic == "$test-topic/some-topic"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(self):
async def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(
hass, mqtt_mock, calls, record_calls
):
"""Test the subscription of $ root and wildcard subtree topics.""" """Test the subscription of $ root and wildcard subtree topics."""
mqtt.subscribe(self.hass, "$test-topic/subtree/#", self.record_calls) await mqtt.async_subscribe(hass, "$test-topic/subtree/#", record_calls)
fire_mqtt_message(self.hass, "$test-topic/subtree/some-topic", "test-payload") async_fire_mqtt_message(hass, "$test-topic/subtree/some-topic", "test-payload")
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == "$test-topic/subtree/some-topic" assert calls[0][0].topic == "$test-topic/subtree/some-topic"
assert self.calls[0][0].payload == "test-payload" assert calls[0][0].payload == "test-payload"
def test_subscribe_special_characters(self):
async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls):
"""Test the subscription to topics with special characters.""" """Test the subscription to topics with special characters."""
topic = "/test-topic/$(.)[^]{-}" topic = "/test-topic/$(.)[^]{-}"
payload = "p4y.l[]a|> ?" payload = "p4y.l[]a|> ?"
mqtt.subscribe(self.hass, topic, self.record_calls) await mqtt.async_subscribe(hass, topic, record_calls)
fire_mqtt_message(self.hass, topic, payload) async_fire_mqtt_message(hass, topic, payload)
self.hass.block_till_done() await hass.async_block_till_done()
assert len(self.calls) == 1 assert len(calls) == 1
assert self.calls[0][0].topic == topic assert calls[0][0].topic == topic
assert self.calls[0][0].payload == payload assert calls[0][0].payload == payload
def test_retained_message_on_subscribe_received(self):
async def test_retained_message_on_subscribe_received(
hass, mqtt_client_mock, mqtt_mock
):
"""Test every subscriber receives retained message on subscribe.""" """Test every subscriber receives retained message on subscribe."""
def side_effect(*args): def side_effect(*args):
async_fire_mqtt_message(self.hass, "test/state", "online") async_fire_mqtt_message(hass, "test/state", "online")
return 0, 0 return 0, 0
self.hass.data["mqtt"]._mqttc.subscribe.side_effect = side_effect mqtt_client_mock.subscribe.side_effect = side_effect
# Fake that the client is connected # Fake that the client is connected
self.hass.data["mqtt"].connected = True mqtt_mock.connected = True
calls_a = MagicMock() calls_a = MagicMock()
mqtt.subscribe(self.hass, "test/state", calls_a) await mqtt.async_subscribe(hass, "test/state", calls_a)
self.hass.block_till_done() await hass.async_block_till_done()
assert calls_a.called assert calls_a.called
calls_b = MagicMock() calls_b = MagicMock()
mqtt.subscribe(self.hass, "test/state", calls_b) await mqtt.async_subscribe(hass, "test/state", calls_b)
self.hass.block_till_done() await hass.async_block_till_done()
assert calls_b.called assert calls_b.called
def test_not_calling_unsubscribe_with_active_subscribers(self):
async def test_not_calling_unsubscribe_with_active_subscribers(
hass, mqtt_client_mock, mqtt_mock
):
"""Test not calling unsubscribe() when other subscribers are active.""" """Test not calling unsubscribe() when other subscribers are active."""
# Fake that the client is connected # Fake that the client is connected
self.hass.data["mqtt"].connected = True mqtt_mock.connected = True
unsub = mqtt.subscribe(self.hass, "test/state", None) unsub = await mqtt.async_subscribe(hass, "test/state", None)
mqtt.subscribe(self.hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None)
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"]._mqttc.subscribe.called assert mqtt_client_mock.subscribe.called
unsub() unsub()
self.hass.block_till_done() await hass.async_block_till_done()
assert not self.hass.data["mqtt"]._mqttc.unsubscribe.called assert not mqtt_client_mock.unsubscribe.called
def test_restore_subscriptions_on_reconnect(self):
async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock):
"""Test subscriptions are restored on reconnect.""" """Test subscriptions are restored on reconnect."""
# Fake that the client is connected # Fake that the client is connected
self.hass.data["mqtt"].connected = True mqtt_mock.connected = True
mqtt.subscribe(self.hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None)
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"]._mqttc.subscribe.call_count == 1 assert mqtt_client_mock.subscribe.call_count == 1
self.hass.data["mqtt"]._mqtt_on_disconnect(None, None, 0) mqtt_mock._mqtt_on_disconnect(None, None, 0)
self.hass.data["mqtt"]._mqtt_on_connect(None, None, None, 0) mqtt_mock._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"]._mqttc.subscribe.call_count == 2 assert mqtt_client_mock.subscribe.call_count == 2
def test_restore_all_active_subscriptions_on_reconnect(self):
async def test_restore_all_active_subscriptions_on_reconnect(
hass, mqtt_client_mock, mqtt_mock
):
"""Test active subscriptions are restored correctly on reconnect.""" """Test active subscriptions are restored correctly on reconnect."""
# Fake that the client is connected # Fake that the client is connected
self.hass.data["mqtt"].connected = True mqtt_mock.connected = True
self.hass.data["mqtt"]._mqttc.subscribe.side_effect = ( mqtt_client_mock.subscribe.side_effect = (
(0, 1), (0, 1),
(0, 2), (0, 2),
(0, 3), (0, 3),
(0, 4), (0, 4),
) )
unsub = mqtt.subscribe(self.hass, "test/state", None, qos=2) unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2)
mqtt.subscribe(self.hass, "test/state", None) await mqtt.async_subscribe(hass, "test/state", None)
mqtt.subscribe(self.hass, "test/state", None, qos=1) await mqtt.async_subscribe(hass, "test/state", None, qos=1)
self.hass.block_till_done() await hass.async_block_till_done()
expected = [ expected = [
call("test/state", 2), call("test/state", 2),
call("test/state", 0), call("test/state", 0),
call("test/state", 1), call("test/state", 1),
] ]
assert self.hass.data["mqtt"]._mqttc.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
unsub() unsub()
self.hass.block_till_done() await hass.async_block_till_done()
assert self.hass.data["mqtt"]._mqttc.unsubscribe.call_count == 0 assert mqtt_client_mock.unsubscribe.call_count == 0
self.hass.data["mqtt"]._mqtt_on_disconnect(None, None, 0) mqtt_mock._mqtt_on_disconnect(None, None, 0)
self.hass.data["mqtt"]._mqtt_on_connect(None, None, None, 0) mqtt_mock._mqtt_on_connect(None, None, None, 0)
self.hass.block_till_done() await hass.async_block_till_done()
expected.append(call("test/state", 1)) expected.append(call("test/state", 1))
assert self.hass.data["mqtt"]._mqttc.subscribe.mock_calls == expected assert mqtt_client_mock.subscribe.mock_calls == expected
async def test_setup_embedded_starts_with_no_config(hass): @pytest.fixture
"""Test setting up embedded server with no config.""" def mqtt_server_start_mock(hass):
"""Mock embedded server start."""
client_config = ("localhost", 1883, "user", "pass", None, "3.1.1") client_config = ("localhost", 1883, "user", "pass", None, "3.1.1")
with patch( with patch(
"homeassistant.components.mqtt.server.async_start", "homeassistant.components.mqtt.server.async_start",
return_value=(True, client_config), return_value=(True, client_config),
) as _start: ) as _start:
await async_mock_mqtt_client(hass, {}) yield _start
assert _start.call_count == 1
async def test_setup_embedded_with_embedded(hass): @pytest.mark.parametrize("mqtt_config", [{}])
async def test_setup_embedded_starts_with_no_config(
hass, mqtt_server_start_mock, mqtt_mock
):
"""Test setting up embedded server with no config.""" """Test setting up embedded server with no config."""
client_config = ("localhost", 1883, "user", "pass", None, "3.1.1") assert mqtt_server_start_mock.call_count == 1
with patch(
"homeassistant.components.mqtt.server.async_start", @pytest.mark.parametrize("mqtt_config", [{"embedded": None}])
return_value=(True, client_config), async def test_setup_embedded_with_embedded(hass, mqtt_server_start_mock, mqtt_mock):
) as _start: """Test setting up embedded server with empty embedded config."""
await async_mock_mqtt_client(hass, {"embedded": None}) assert mqtt_server_start_mock.call_count == 1
assert _start.call_count == 1
async def test_setup_logs_error_if_no_connect_broker(hass, caplog): async def test_setup_logs_error_if_no_connect_broker(hass, caplog):
@ -776,42 +793,40 @@ async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, moc
assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1 assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1
async def test_birth_message(hass): @pytest.mark.parametrize(
"""Test sending birth message.""" "mqtt_config",
mqtt_client = await async_mock_mqtt_client( [
hass,
{ {
mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BROKER: "mock-broker",
mqtt.CONF_BIRTH_MESSAGE: { mqtt.CONF_BIRTH_MESSAGE: {
mqtt.ATTR_TOPIC: "birth", mqtt.ATTR_TOPIC: "birth",
mqtt.ATTR_PAYLOAD: "birth", mqtt.ATTR_PAYLOAD: "birth",
}, },
}, }
],
) )
async def test_birth_message(hass, mqtt_client_mock, mqtt_mock):
"""Test sending birth message."""
calls = [] calls = []
mqtt_client.publish.side_effect = lambda *args: calls.append(args) mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args)
hass.data["mqtt"]._mqtt_on_connect(None, None, 0, 0) mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
assert calls[-1] == ("birth", "birth", 0, False) assert calls[-1] == ("birth", "birth", 0, False)
async def test_mqtt_subscribes_topics_on_connect(hass): async def test_mqtt_subscribes_topics_on_connect(hass, mqtt_client_mock, mqtt_mock):
"""Test subscription to topic on connect.""" """Test subscription to topic on connect."""
mqtt_client = await async_mock_mqtt_client(hass) await mqtt.async_subscribe(hass, "topic/test", None)
await mqtt.async_subscribe(hass, "home/sensor", None, 2)
hass.data["mqtt"].subscriptions = [ await mqtt.async_subscribe(hass, "still/pending", None)
mqtt.Subscription("topic/test", None), await mqtt.async_subscribe(hass, "still/pending", None, 1)
mqtt.Subscription("home/sensor", None, 2),
mqtt.Subscription("still/pending", None),
mqtt.Subscription("still/pending", None, 1),
]
hass.add_job = MagicMock() hass.add_job = MagicMock()
hass.data["mqtt"]._mqtt_on_connect(None, None, 0, 0) mqtt_mock._mqtt_on_connect(None, None, 0, 0)
await hass.async_block_till_done() await hass.async_block_till_done()
assert mqtt_client.disconnect.call_count == 0 assert mqtt_client_mock.disconnect.call_count == 0
expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1} expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1}
calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls} calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls}
@ -824,9 +839,8 @@ async def test_setup_fails_without_config(hass):
@pytest.mark.no_fail_on_log_exception @pytest.mark.no_fail_on_log_exception
async def test_message_callback_exception_gets_logged(hass, caplog): async def test_message_callback_exception_gets_logged(hass, caplog, mqtt_mock):
"""Test exception raised by message handler.""" """Test exception raised by message handler."""
await async_mock_mqtt_component(hass)
@callback @callback
def bad_handler(*args): def bad_handler(*args):
@ -843,10 +857,8 @@ async def test_message_callback_exception_gets_logged(hass, caplog):
) )
async def test_mqtt_ws_subscription(hass, hass_ws_client): async def test_mqtt_ws_subscription(hass, hass_ws_client, mqtt_mock):
"""Test MQTT websocket subscription.""" """Test MQTT websocket subscription."""
await async_mock_mqtt_component(hass)
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json({"id": 5, "type": "mqtt/subscribe", "topic": "test-topic"}) await client.send_json({"id": 5, "type": "mqtt/subscribe", "topic": "test-topic"})
response = await client.receive_json() response = await client.receive_json()
@ -869,10 +881,8 @@ async def test_mqtt_ws_subscription(hass, hass_ws_client):
assert response["success"] assert response["success"]
async def test_dump_service(hass): async def test_dump_service(hass, mqtt_mock):
"""Test that we can dump a topic.""" """Test that we can dump a topic."""
await async_mock_mqtt_component(hass)
mopen = mock_open() mopen = mock_open()
await hass.services.async_call( await hass.services.async_call(