diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index bab36ee854e..54f745d5bb2 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -366,11 +366,19 @@ def async_publish( @bind_hass def publish_template( hass: HomeAssistantType, topic, payload_template, qos=None, retain=None +) -> None: + """Publish message to an MQTT topic.""" + hass.add_job(async_publish_template, hass, topic, payload_template, qos, retain) + + +@bind_hass +def async_publish_template( + hass: HomeAssistantType, topic, payload_template, qos=None, retain=None ) -> None: """Publish message to an MQTT topic using a template payload.""" data = _build_publish_data(topic, qos, retain) data[ATTR_PAYLOAD_TEMPLATE] = payload_template - hass.services.call(DOMAIN, SERVICE_PUBLISH, data) + hass.async_create_task(hass.services.async_call(DOMAIN, SERVICE_PUBLISH, data)) def wrap_msg_callback(msg_callback: MessageCallbackType) -> MessageCallbackType: diff --git a/tests/common.py b/tests/common.py index 4e457496dee..93d70f0e12c 100644 --- a/tests/common.py +++ b/tests/common.py @@ -344,10 +344,13 @@ async def async_mock_mqtt_component(hass, config=None): assert result await hass.async_block_till_done() - hass.data["mqtt"] = MagicMock( + mqtt_component_mock = MagicMock( spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"] ) + hass.data["mqtt"].connected = mqtt_component_mock.connected + mqtt_component_mock._mqttc = mock_client + hass.data["mqtt"] = mqtt_component_mock return hass.data["mqtt"] diff --git a/tests/components/mqtt/conftest.py b/tests/components/mqtt/conftest.py index 290682549f5..2c5eaf3e54e 100644 --- a/tests/components/mqtt/conftest.py +++ b/tests/components/mqtt/conftest.py @@ -1,12 +1,52 @@ """Test fixtures for mqtt component.""" import pytest -from tests.common import async_mock_mqtt_component +from homeassistant import core as ha +from homeassistant.components import mqtt +from homeassistant.setup import async_setup_component + +from tests.async_mock import MagicMock, patch +from tests.common import async_fire_mqtt_message @pytest.fixture -def mqtt_mock(loop, hass): - """Fixture to mock MQTT.""" - client = loop.run_until_complete(async_mock_mqtt_component(hass)) - client.reset_mock() - return client +def mqtt_config(): + """Fixture to allow overriding MQTT config.""" + return None + + +@pytest.fixture +def mqtt_client_mock(hass): + """Fixture to mock MQTT client.""" + + @ha.callback + def _async_fire_mqtt_message(topic, payload, qos, retain): + async_fire_mqtt_message(hass, topic, payload, qos, retain) + + with patch("paho.mqtt.client.Client") as mock_client: + mock_client = mock_client.return_value + mock_client.connect.return_value = 0 + mock_client.subscribe.return_value = (0, 0) + mock_client.unsubscribe.return_value = (0, 0) + mock_client.publish.side_effect = _async_fire_mqtt_message + yield mock_client + + +@pytest.fixture +async def mqtt_mock(hass, mqtt_client_mock, mqtt_config): + """Fixture to mock MQTT component.""" + if mqtt_config is None: + mqtt_config = {mqtt.CONF_BROKER: "mock-broker"} + + result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: mqtt_config}) + assert result + await hass.async_block_till_done() + + mqtt_component_mock = MagicMock(spec_set=hass.data["mqtt"], wraps=hass.data["mqtt"]) + hass.data["mqtt"].connected = mqtt_component_mock.connected + mqtt_component_mock._mqttc = mqtt_client_mock + + hass.data["mqtt"] = mqtt_component_mock + component = hass.data["mqtt"] + component.reset_mock() + return component diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 369cf76a5e2..89b5a7423f8 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -2,7 +2,6 @@ from datetime import datetime, timedelta import json import ssl -import unittest import pytest import voluptuous as vol @@ -27,13 +26,8 @@ from tests.common import ( MockConfigEntry, async_fire_mqtt_message, async_fire_time_changed, - async_mock_mqtt_component, - fire_mqtt_message, - get_test_home_assistant, mock_device_registry, - mock_mqtt_component, mock_registry, - threadsafe_coroutine_factory, ) from tests.testing_config.custom_components.test.sensor import DEVICE_CLASSES @@ -64,619 +58,642 @@ def mock_mqtt(): yield mock_mqtt -async def async_mock_mqtt_client(hass, config=None): - """Mock the MQTT paho client.""" - if config is None: - config = {mqtt.CONF_BROKER: "mock-broker"} - - with patch("paho.mqtt.client.Client") as mock_client: - mock_client().connect.return_value = 0 - mock_client().subscribe.return_value = (0, 0) - mock_client().unsubscribe.return_value = (0, 0) - mock_client().publish.return_value = (0, 0) - result = await async_setup_component(hass, mqtt.DOMAIN, {mqtt.DOMAIN: config}) - assert result - await hass.async_block_till_done() - return mock_client() +@pytest.fixture +def calls(): + """Fixture to record calls.""" + return [] -mock_mqtt_client = threadsafe_coroutine_factory(async_mock_mqtt_client) - - -# pylint: disable=invalid-name -class TestMQTTComponent(unittest.TestCase): - """Test the MQTT component.""" - - def setUp(self): # pylint: disable=invalid-name - """Set up things to be run when tests are started.""" - self.hass = get_test_home_assistant() - mock_mqtt_component(self.hass) - self.calls = [] - self.addCleanup(self.tear_down_cleanup) - - def tear_down_cleanup(self): - """Stop everything that was started.""" - self.hass.stop() +@pytest.fixture +def record_calls(calls): + """Fixture to record calls.""" @callback - def record_calls(self, *args): + def record_calls(*args): """Record calls.""" - self.calls.append(args) + calls.append(args) - def aiohttp_client_stops_on_home_assistant_start(self): - """Test if client stops on HA stop.""" - self.hass.bus.fire(EVENT_HOMEASSISTANT_STOP) - self.hass.block_till_done() - assert self.hass.data["mqtt"].async_disconnect.called + return record_calls - def test_publish_calls_service(self): - """Test the publishing of call to services.""" - self.hass.bus.listen_once(EVENT_CALL_SERVICE, self.record_calls) - mqtt.publish(self.hass, "test-topic", "test-payload") +async def test_mqtt_connects_on_home_assistant_mqtt_setup( + hass, mqtt_client_mock, mqtt_mock +): + """Test if client is connected after mqtt init on bootstrap.""" + assert mqtt_client_mock.connect.call_count == 1 - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].data["service_data"][mqtt.ATTR_TOPIC] == "test-topic" - assert ( - self.calls[0][0].data["service_data"][mqtt.ATTR_PAYLOAD] == "test-payload" - ) +async def test_mqtt_disconnects_on_home_assistant_stop(hass, mqtt_mock): + """Test if client stops on HA stop.""" + hass.bus.fire(EVENT_HOMEASSISTANT_STOP) + await hass.async_block_till_done() + await hass.async_block_till_done() + assert mqtt_mock.async_disconnect.called - def test_service_call_without_topic_does_not_publish(self): - """Test the service call if topic is missing.""" - self.hass.bus.fire( - EVENT_CALL_SERVICE, - {ATTR_DOMAIN: mqtt.DOMAIN, ATTR_SERVICE: mqtt.SERVICE_PUBLISH}, - ) - self.hass.block_till_done() - assert not self.hass.data["mqtt"].async_publish.called - def test_service_call_with_template_payload_renders_template(self): - """Test the service call with rendered template. +async def test_publish_calls_service(hass, mqtt_mock, calls, record_calls): + """Test the publishing of call to services.""" + hass.bus.async_listen_once(EVENT_CALL_SERVICE, record_calls) - If 'payload_template' is provided and 'payload' is not, then render it. - """ - mqtt.publish_template(self.hass, "test/topic", "{{ 1+1 }}") - self.hass.block_till_done() - assert self.hass.data["mqtt"].async_publish.called - assert self.hass.data["mqtt"].async_publish.call_args[0][1] == "2" + mqtt.async_publish(hass, "test-topic", "test-payload") - def test_service_call_with_payload_doesnt_render_template(self): - """Test the service call with unrendered template. + await hass.async_block_till_done() - If both 'payload' and 'payload_template' are provided then fail. - """ - payload = "not a template" - payload_template = "a template" - with pytest.raises(vol.Invalid): - self.hass.services.call( - mqtt.DOMAIN, - mqtt.SERVICE_PUBLISH, - { - mqtt.ATTR_TOPIC: "test/topic", - mqtt.ATTR_PAYLOAD: payload, - mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template, - }, - blocking=True, - ) - assert not self.hass.data["mqtt"].async_publish.called + assert len(calls) == 1 + assert calls[0][0].data["service_data"][mqtt.ATTR_TOPIC] == "test-topic" + assert calls[0][0].data["service_data"][mqtt.ATTR_PAYLOAD] == "test-payload" - def test_service_call_with_ascii_qos_retain_flags(self): - """Test the service call with args that can be misinterpreted. - Empty payload message and ascii formatted qos and retain flags. - """ - self.hass.services.call( +async def test_service_call_without_topic_does_not_publish(hass, mqtt_mock): + """Test the service call if topic is missing.""" + hass.bus.fire( + EVENT_CALL_SERVICE, + {ATTR_DOMAIN: mqtt.DOMAIN, ATTR_SERVICE: mqtt.SERVICE_PUBLISH}, + ) + await hass.async_block_till_done() + assert not mqtt_mock.async_publish.called + + +async def test_service_call_with_template_payload_renders_template(hass, mqtt_mock): + """Test the service call with rendered template. + + If 'payload_template' is provided and 'payload' is not, then render it. + """ + mqtt.async_publish_template(hass, "test/topic", "{{ 1+1 }}") + await hass.async_block_till_done() + assert mqtt_mock.async_publish.called + assert mqtt_mock.async_publish.call_args[0][1] == "2" + + +async def test_service_call_with_payload_doesnt_render_template(hass, mqtt_mock): + """Test the service call with unrendered template. + + If both 'payload' and 'payload_template' are provided then fail. + """ + payload = "not a template" + payload_template = "a template" + with pytest.raises(vol.Invalid): + await hass.services.async_call( mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, { mqtt.ATTR_TOPIC: "test/topic", - mqtt.ATTR_PAYLOAD: "", - mqtt.ATTR_QOS: "2", - mqtt.ATTR_RETAIN: "no", + mqtt.ATTR_PAYLOAD: payload, + mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template, }, blocking=True, ) - assert self.hass.data["mqtt"].async_publish.called - assert self.hass.data["mqtt"].async_publish.call_args[0][2] == 2 - assert not self.hass.data["mqtt"].async_publish.call_args[0][3] + assert not mqtt_mock.async_publish.called - def test_validate_topic(self): - """Test topic name/filter validation.""" - # Invalid UTF-8, must not contain U+D800 to U+DFFF. - with pytest.raises(vol.Invalid): - mqtt.valid_topic("\ud800") - with pytest.raises(vol.Invalid): - mqtt.valid_topic("\udfff") - # Topic MUST NOT be empty - with pytest.raises(vol.Invalid): - mqtt.valid_topic("") - # Topic MUST NOT be longer than 65535 encoded bytes. - with pytest.raises(vol.Invalid): - mqtt.valid_topic("ü" * 32768) - # UTF-8 MUST NOT include null character - with pytest.raises(vol.Invalid): - mqtt.valid_topic("bad\0one") - # Topics "SHOULD NOT" include these special characters - # (not MUST NOT, RFC2119). The receiver MAY close the connection. - mqtt.valid_topic("\u0001") - mqtt.valid_topic("\u001F") - mqtt.valid_topic("\u009F") - mqtt.valid_topic("\u009F") - mqtt.valid_topic("\uffff") +async def test_service_call_with_ascii_qos_retain_flags(hass, mqtt_mock): + """Test the service call with args that can be misinterpreted. - def test_validate_subscribe_topic(self): - """Test invalid subscribe topics.""" - mqtt.valid_subscribe_topic("#") - mqtt.valid_subscribe_topic("sport/#") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("sport/#/") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("foo/bar#") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("foo/#/bar") + Empty payload message and ascii formatted qos and retain flags. + """ + await hass.services.async_call( + mqtt.DOMAIN, + mqtt.SERVICE_PUBLISH, + { + mqtt.ATTR_TOPIC: "test/topic", + mqtt.ATTR_PAYLOAD: "", + mqtt.ATTR_QOS: "2", + mqtt.ATTR_RETAIN: "no", + }, + blocking=True, + ) + assert mqtt_mock.async_publish.called + assert mqtt_mock.async_publish.call_args[0][2] == 2 + assert not mqtt_mock.async_publish.call_args[0][3] - mqtt.valid_subscribe_topic("+") - mqtt.valid_subscribe_topic("+/tennis/#") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("sport+") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("sport+/") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("sport/+1") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("sport/+#") - with pytest.raises(vol.Invalid): - mqtt.valid_subscribe_topic("bad+topic") - mqtt.valid_subscribe_topic("sport/+/player1") - mqtt.valid_subscribe_topic("/finance") - mqtt.valid_subscribe_topic("+/+") - mqtt.valid_subscribe_topic("$SYS/#") - def test_validate_publish_topic(self): - """Test invalid publish topics.""" - with pytest.raises(vol.Invalid): - mqtt.valid_publish_topic("pub+") - with pytest.raises(vol.Invalid): - mqtt.valid_publish_topic("pub/+") - with pytest.raises(vol.Invalid): - mqtt.valid_publish_topic("1#") - with pytest.raises(vol.Invalid): - mqtt.valid_publish_topic("bad+topic") - mqtt.valid_publish_topic("//") +def test_validate_topic(): + """Test topic name/filter validation.""" + # Invalid UTF-8, must not contain U+D800 to U+DFFF. + with pytest.raises(vol.Invalid): + mqtt.valid_topic("\ud800") + with pytest.raises(vol.Invalid): + mqtt.valid_topic("\udfff") + # Topic MUST NOT be empty + with pytest.raises(vol.Invalid): + mqtt.valid_topic("") + # Topic MUST NOT be longer than 65535 encoded bytes. + with pytest.raises(vol.Invalid): + mqtt.valid_topic("ü" * 32768) + # UTF-8 MUST NOT include null character + with pytest.raises(vol.Invalid): + mqtt.valid_topic("bad\0one") - # Topic names beginning with $ SHOULD NOT be used, but can - mqtt.valid_publish_topic("$SYS/") + # Topics "SHOULD NOT" include these special characters + # (not MUST NOT, RFC2119). The receiver MAY close the connection. + mqtt.valid_topic("\u0001") + mqtt.valid_topic("\u001F") + mqtt.valid_topic("\u009F") + mqtt.valid_topic("\u009F") + mqtt.valid_topic("\uffff") - def test_entity_device_info_schema(self): - """Test MQTT entity device info validation.""" - # just identifier - mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": ["abcd"]}) - mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": "abcd"}) - # just connection - mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( - {"connections": [["mac", "02:5b:26:a8:dc:12"]]} - ) - # full device info + +def test_validate_subscribe_topic(): + """Test invalid subscribe topics.""" + mqtt.valid_subscribe_topic("#") + mqtt.valid_subscribe_topic("sport/#") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("sport/#/") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("foo/bar#") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("foo/#/bar") + + mqtt.valid_subscribe_topic("+") + mqtt.valid_subscribe_topic("+/tennis/#") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("sport+") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("sport+/") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("sport/+1") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("sport/+#") + with pytest.raises(vol.Invalid): + mqtt.valid_subscribe_topic("bad+topic") + mqtt.valid_subscribe_topic("sport/+/player1") + mqtt.valid_subscribe_topic("/finance") + mqtt.valid_subscribe_topic("+/+") + mqtt.valid_subscribe_topic("$SYS/#") + + +def test_validate_publish_topic(): + """Test invalid publish topics.""" + with pytest.raises(vol.Invalid): + mqtt.valid_publish_topic("pub+") + with pytest.raises(vol.Invalid): + mqtt.valid_publish_topic("pub/+") + with pytest.raises(vol.Invalid): + mqtt.valid_publish_topic("1#") + with pytest.raises(vol.Invalid): + mqtt.valid_publish_topic("bad+topic") + mqtt.valid_publish_topic("//") + + # Topic names beginning with $ SHOULD NOT be used, but can + mqtt.valid_publish_topic("$SYS/") + + +def test_entity_device_info_schema(): + """Test MQTT entity device info validation.""" + # just identifier + mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": ["abcd"]}) + mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"identifiers": "abcd"}) + # just connection + mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA({"connections": [["mac", "02:5b:26:a8:dc:12"]]}) + # full device info + mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( + { + "identifiers": ["helloworld", "hello"], + "connections": [["mac", "02:5b:26:a8:dc:12"], ["zigbee", "zigbee_id"]], + "manufacturer": "Whatever", + "name": "Beer", + "model": "Glass", + "sw_version": "0.1-beta", + } + ) + # full device info with via_device + mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( + { + "identifiers": ["helloworld", "hello"], + "connections": [["mac", "02:5b:26:a8:dc:12"], ["zigbee", "zigbee_id"]], + "manufacturer": "Whatever", + "name": "Beer", + "model": "Glass", + "sw_version": "0.1-beta", + "via_device": "test-hub", + } + ) + # no identifiers + with pytest.raises(vol.Invalid): mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( { - "identifiers": ["helloworld", "hello"], - "connections": [["mac", "02:5b:26:a8:dc:12"], ["zigbee", "zigbee_id"]], "manufacturer": "Whatever", "name": "Beer", "model": "Glass", "sw_version": "0.1-beta", } ) - # full device info with via_device + # empty identifiers + with pytest.raises(vol.Invalid): mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( - { - "identifiers": ["helloworld", "hello"], - "connections": [["mac", "02:5b:26:a8:dc:12"], ["zigbee", "zigbee_id"]], - "manufacturer": "Whatever", - "name": "Beer", - "model": "Glass", - "sw_version": "0.1-beta", - "via_device": "test-hub", - } + {"identifiers": [], "connections": [], "name": "Beer"} ) - # no identifiers - with pytest.raises(vol.Invalid): - mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( - { - "manufacturer": "Whatever", - "name": "Beer", - "model": "Glass", - "sw_version": "0.1-beta", - } - ) - # empty identifiers - with pytest.raises(vol.Invalid): - mqtt.MQTT_ENTITY_DEVICE_INFO_SCHEMA( - {"identifiers": [], "connections": [], "name": "Beer"} - ) -# pylint: disable=invalid-name -class TestMQTTCallbacks(unittest.TestCase): - """Test the MQTT callbacks.""" +async def test_receiving_non_utf8_message_gets_logged( + hass, mqtt_mock, calls, record_calls, caplog +): + """Test receiving a non utf8 encoded message.""" + await mqtt.async_subscribe(hass, "test-topic", record_calls) - def setUp(self): # pylint: disable=invalid-name - """Set up things to be run when tests are started.""" - self.hass = get_test_home_assistant() - mock_mqtt_client(self.hass) - self.calls = [] - self.addCleanup(self.tear_down_cleanup) + async_fire_mqtt_message(hass, "test-topic", b"\x9a") - def tear_down_cleanup(self): - """Stop everything that was started.""" - self.hass.stop() + await hass.async_block_till_done() + assert ( + "Can't decode payload b'\\x9a' on test-topic with encoding utf-8" in caplog.text + ) + + +async def test_all_subscriptions_run_when_decode_fails( + hass, mqtt_mock, calls, record_calls +): + """Test all other subscriptions still run when decode fails for one.""" + await mqtt.async_subscribe(hass, "test-topic", record_calls, encoding="ascii") + await mqtt.async_subscribe(hass, "test-topic", record_calls) + + async_fire_mqtt_message(hass, "test-topic", TEMP_CELSIUS) + + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_subscribe_topic(hass, mqtt_mock, calls, record_calls): + """Test the subscription of a topic.""" + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "test-topic" + assert calls[0][0].payload == "test-payload" + + unsub() + + async_fire_mqtt_message(hass, "test-topic", "test-payload") + + await hass.async_block_till_done() + assert len(calls) == 1 + + +async def test_subscribe_deprecated(hass, mqtt_mock): + """Test the subscription of a topic using deprecated callback signature.""" + calls = [] @callback - def record_calls(self, *args): + def record_calls(topic, payload, qos): """Record calls.""" - self.calls.append(args) + calls.append((topic, payload, qos)) - def aiohttp_client_starts_on_home_assistant_mqtt_setup(self): - """Test if client is connected after mqtt init on bootstrap.""" - assert self.hass.data["mqtt"]._mqttc.connect.call_count == 1 + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) - def test_receiving_non_utf8_message_gets_logged(self): - """Test receiving a non utf8 encoded message.""" - mqtt.subscribe(self.hass, "test-topic", self.record_calls) + async_fire_mqtt_message(hass, "test-topic", "test-payload") - with self.assertLogs(level="WARNING") as test_handle: - fire_mqtt_message(self.hass, "test-topic", b"\x9a") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0] == "test-topic" + assert calls[0][1] == "test-payload" - self.hass.block_till_done() - assert ( - "WARNING:homeassistant.components.mqtt:Can't decode payload " - "b'\\x9a' on test-topic with encoding utf-8" in test_handle.output[0] - ) + unsub() - def test_all_subscriptions_run_when_decode_fails(self): - """Test all other subscriptions still run when decode fails for one.""" - mqtt.subscribe(self.hass, "test-topic", self.record_calls, encoding="ascii") - mqtt.subscribe(self.hass, "test-topic", self.record_calls) + async_fire_mqtt_message(hass, "test-topic", "test-payload") - fire_mqtt_message(self.hass, "test-topic", TEMP_CELSIUS) + await hass.async_block_till_done() + assert len(calls) == 1 - self.hass.block_till_done() - assert len(self.calls) == 1 - def test_subscribe_topic(self): - """Test the subscription of a topic.""" - unsub = mqtt.subscribe(self.hass, "test-topic", self.record_calls) +async def test_subscribe_deprecated_async(hass, mqtt_mock): + """Test the subscription of a topic using deprecated callback signature.""" + calls = [] - fire_mqtt_message(self.hass, "test-topic", "test-payload") + @callback + async def record_calls(topic, payload, qos): + """Record calls.""" + calls.append((topic, payload, qos)) - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "test-topic" - assert self.calls[0][0].payload == "test-payload" + unsub = await mqtt.async_subscribe(hass, "test-topic", record_calls) - unsub() + async_fire_mqtt_message(hass, "test-topic", "test-payload") - fire_mqtt_message(self.hass, "test-topic", "test-payload") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0] == "test-topic" + assert calls[0][1] == "test-payload" - self.hass.block_till_done() - assert len(self.calls) == 1 + unsub() - def test_subscribe_deprecated(self): - """Test the subscription of a topic using deprecated callback signature.""" - calls = [] + async_fire_mqtt_message(hass, "test-topic", "test-payload") - @callback - def record_calls(topic, payload, qos): - """Record calls.""" - calls.append((topic, payload, qos)) + await hass.async_block_till_done() + assert len(calls) == 1 - unsub = mqtt.subscribe(self.hass, "test-topic", record_calls) - fire_mqtt_message(self.hass, "test-topic", "test-payload") +async def test_subscribe_topic_not_match(hass, mqtt_mock, calls, record_calls): + """Test if subscribed topic is not a match.""" + await mqtt.async_subscribe(hass, "test-topic", record_calls) - self.hass.block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" + async_fire_mqtt_message(hass, "another-test-topic", "test-payload") - unsub() + await hass.async_block_till_done() + assert len(calls) == 0 - fire_mqtt_message(self.hass, "test-topic", "test-payload") - self.hass.block_till_done() - assert len(calls) == 1 +async def test_subscribe_topic_level_wildcard(hass, mqtt_mock, calls, record_calls): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/+/on", record_calls) - def test_subscribe_deprecated_async(self): - """Test the subscription of a topic using deprecated callback signature.""" - calls = [] + async_fire_mqtt_message(hass, "test-topic/bier/on", "test-payload") - @callback - async def record_calls(topic, payload, qos): - """Record calls.""" - calls.append((topic, payload, qos)) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "test-topic/bier/on" + assert calls[0][0].payload == "test-payload" - unsub = mqtt.subscribe(self.hass, "test-topic", record_calls) - fire_mqtt_message(self.hass, "test-topic", "test-payload") +async def test_subscribe_topic_level_wildcard_no_subtree_match( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/+/on", record_calls) - self.hass.block_till_done() - assert len(calls) == 1 - assert calls[0][0] == "test-topic" - assert calls[0][1] == "test-payload" + async_fire_mqtt_message(hass, "test-topic/bier", "test-payload") - unsub() + await hass.async_block_till_done() + assert len(calls) == 0 - fire_mqtt_message(self.hass, "test-topic", "test-payload") - self.hass.block_till_done() - assert len(calls) == 1 +async def test_subscribe_topic_level_wildcard_root_topic_no_subtree_match( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/#", record_calls) - def test_subscribe_topic_not_match(self): - """Test if subscribed topic is not a match.""" - mqtt.subscribe(self.hass, "test-topic", self.record_calls) + async_fire_mqtt_message(hass, "test-topic-123", "test-payload") - fire_mqtt_message(self.hass, "another-test-topic", "test-payload") + await hass.async_block_till_done() + assert len(calls) == 0 - self.hass.block_till_done() - assert len(self.calls) == 0 - def test_subscribe_topic_level_wildcard(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/+/on", self.record_calls) +async def test_subscribe_topic_subtree_wildcard_subtree_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/#", record_calls) - fire_mqtt_message(self.hass, "test-topic/bier/on", "test-payload") + async_fire_mqtt_message(hass, "test-topic/bier/on", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "test-topic/bier/on" - assert self.calls[0][0].payload == "test-payload" + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "test-topic/bier/on" + assert calls[0][0].payload == "test-payload" - def test_subscribe_topic_level_wildcard_no_subtree_match(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/+/on", self.record_calls) - fire_mqtt_message(self.hass, "test-topic/bier", "test-payload") +async def test_subscribe_topic_subtree_wildcard_root_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/#", record_calls) - self.hass.block_till_done() - assert len(self.calls) == 0 + async_fire_mqtt_message(hass, "test-topic", "test-payload") - def test_subscribe_topic_level_wildcard_root_topic_no_subtree_match(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "test-topic" + assert calls[0][0].payload == "test-payload" - fire_mqtt_message(self.hass, "test-topic-123", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 0 +async def test_subscribe_topic_subtree_wildcard_no_match( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "test-topic/#", record_calls) - def test_subscribe_topic_subtree_wildcard_subtree_topic(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) + async_fire_mqtt_message(hass, "another-test-topic", "test-payload") - fire_mqtt_message(self.hass, "test-topic/bier/on", "test-payload") + await hass.async_block_till_done() + assert len(calls) == 0 - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "test-topic/bier/on" - assert self.calls[0][0].payload == "test-payload" - def test_subscribe_topic_subtree_wildcard_root_topic(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) +async def test_subscribe_topic_level_wildcard_and_wildcard_root_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls) - fire_mqtt_message(self.hass, "test-topic", "test-payload") + async_fire_mqtt_message(hass, "hi/test-topic", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "test-topic" - assert self.calls[0][0].payload == "test-payload" + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "hi/test-topic" + assert calls[0][0].payload == "test-payload" - def test_subscribe_topic_subtree_wildcard_no_match(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "test-topic/#", self.record_calls) - fire_mqtt_message(self.hass, "another-test-topic", "test-payload") +async def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls) - self.hass.block_till_done() - assert len(self.calls) == 0 + async_fire_mqtt_message(hass, "hi/test-topic/here-iam", "test-payload") - def test_subscribe_topic_level_wildcard_and_wildcard_root_topic(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "hi/test-topic/here-iam" + assert calls[0][0].payload == "test-payload" - fire_mqtt_message(self.hass, "hi/test-topic", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "hi/test-topic" - assert self.calls[0][0].payload == "test-payload" +async def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls) - def test_subscribe_topic_level_wildcard_and_wildcard_subtree_topic(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) + async_fire_mqtt_message(hass, "hi/here-iam/test-topic", "test-payload") - fire_mqtt_message(self.hass, "hi/test-topic/here-iam", "test-payload") + await hass.async_block_till_done() + assert len(calls) == 0 - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "hi/test-topic/here-iam" - assert self.calls[0][0].payload == "test-payload" - def test_subscribe_topic_level_wildcard_and_wildcard_level_no_match(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) +async def test_subscribe_topic_level_wildcard_and_wildcard_no_match( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of wildcard topics.""" + await mqtt.async_subscribe(hass, "+/test-topic/#", record_calls) - fire_mqtt_message(self.hass, "hi/here-iam/test-topic", "test-payload") + async_fire_mqtt_message(hass, "hi/another-test-topic", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 0 + await hass.async_block_till_done() + assert len(calls) == 0 - def test_subscribe_topic_level_wildcard_and_wildcard_no_match(self): - """Test the subscription of wildcard topics.""" - mqtt.subscribe(self.hass, "+/test-topic/#", self.record_calls) - fire_mqtt_message(self.hass, "hi/another-test-topic", "test-payload") +async def test_subscribe_topic_sys_root(hass, mqtt_mock, calls, record_calls): + """Test the subscription of $ root topics.""" + await mqtt.async_subscribe(hass, "$test-topic/subtree/on", record_calls) - self.hass.block_till_done() - assert len(self.calls) == 0 + async_fire_mqtt_message(hass, "$test-topic/subtree/on", "test-payload") - def test_subscribe_topic_sys_root(self): - """Test the subscription of $ root topics.""" - mqtt.subscribe(self.hass, "$test-topic/subtree/on", self.record_calls) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "$test-topic/subtree/on" + assert calls[0][0].payload == "test-payload" - fire_mqtt_message(self.hass, "$test-topic/subtree/on", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "$test-topic/subtree/on" - assert self.calls[0][0].payload == "test-payload" +async def test_subscribe_topic_sys_root_and_wildcard_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of $ root and wildcard topics.""" + await mqtt.async_subscribe(hass, "$test-topic/#", record_calls) - def test_subscribe_topic_sys_root_and_wildcard_topic(self): - """Test the subscription of $ root and wildcard topics.""" - mqtt.subscribe(self.hass, "$test-topic/#", self.record_calls) + async_fire_mqtt_message(hass, "$test-topic/some-topic", "test-payload") - fire_mqtt_message(self.hass, "$test-topic/some-topic", "test-payload") + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "$test-topic/some-topic" + assert calls[0][0].payload == "test-payload" - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "$test-topic/some-topic" - assert self.calls[0][0].payload == "test-payload" - def test_subscribe_topic_sys_root_and_wildcard_subtree_topic(self): - """Test the subscription of $ root and wildcard subtree topics.""" - mqtt.subscribe(self.hass, "$test-topic/subtree/#", self.record_calls) +async def test_subscribe_topic_sys_root_and_wildcard_subtree_topic( + hass, mqtt_mock, calls, record_calls +): + """Test the subscription of $ root and wildcard subtree topics.""" + await mqtt.async_subscribe(hass, "$test-topic/subtree/#", record_calls) - fire_mqtt_message(self.hass, "$test-topic/subtree/some-topic", "test-payload") + async_fire_mqtt_message(hass, "$test-topic/subtree/some-topic", "test-payload") - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == "$test-topic/subtree/some-topic" - assert self.calls[0][0].payload == "test-payload" + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == "$test-topic/subtree/some-topic" + assert calls[0][0].payload == "test-payload" - def test_subscribe_special_characters(self): - """Test the subscription to topics with special characters.""" - topic = "/test-topic/$(.)[^]{-}" - payload = "p4y.l[]a|> ?" - mqtt.subscribe(self.hass, topic, self.record_calls) +async def test_subscribe_special_characters(hass, mqtt_mock, calls, record_calls): + """Test the subscription to topics with special characters.""" + topic = "/test-topic/$(.)[^]{-}" + payload = "p4y.l[]a|> ?" - fire_mqtt_message(self.hass, topic, payload) - self.hass.block_till_done() - assert len(self.calls) == 1 - assert self.calls[0][0].topic == topic - assert self.calls[0][0].payload == payload + await mqtt.async_subscribe(hass, topic, record_calls) - def test_retained_message_on_subscribe_received(self): - """Test every subscriber receives retained message on subscribe.""" + async_fire_mqtt_message(hass, topic, payload) + await hass.async_block_till_done() + assert len(calls) == 1 + assert calls[0][0].topic == topic + assert calls[0][0].payload == payload - def side_effect(*args): - async_fire_mqtt_message(self.hass, "test/state", "online") - return 0, 0 - self.hass.data["mqtt"]._mqttc.subscribe.side_effect = side_effect +async def test_retained_message_on_subscribe_received( + hass, mqtt_client_mock, mqtt_mock +): + """Test every subscriber receives retained message on subscribe.""" - # Fake that the client is connected - self.hass.data["mqtt"].connected = True + def side_effect(*args): + async_fire_mqtt_message(hass, "test/state", "online") + return 0, 0 - calls_a = MagicMock() - mqtt.subscribe(self.hass, "test/state", calls_a) - self.hass.block_till_done() - assert calls_a.called + mqtt_client_mock.subscribe.side_effect = side_effect - calls_b = MagicMock() - mqtt.subscribe(self.hass, "test/state", calls_b) - self.hass.block_till_done() - assert calls_b.called + # Fake that the client is connected + mqtt_mock.connected = True - def test_not_calling_unsubscribe_with_active_subscribers(self): - """Test not calling unsubscribe() when other subscribers are active.""" - # Fake that the client is connected - self.hass.data["mqtt"].connected = True + calls_a = MagicMock() + await mqtt.async_subscribe(hass, "test/state", calls_a) + await hass.async_block_till_done() + assert calls_a.called - unsub = mqtt.subscribe(self.hass, "test/state", None) - mqtt.subscribe(self.hass, "test/state", None) - self.hass.block_till_done() - assert self.hass.data["mqtt"]._mqttc.subscribe.called + calls_b = MagicMock() + await mqtt.async_subscribe(hass, "test/state", calls_b) + await hass.async_block_till_done() + assert calls_b.called - unsub() - self.hass.block_till_done() - assert not self.hass.data["mqtt"]._mqttc.unsubscribe.called - def test_restore_subscriptions_on_reconnect(self): - """Test subscriptions are restored on reconnect.""" - # Fake that the client is connected - self.hass.data["mqtt"].connected = True +async def test_not_calling_unsubscribe_with_active_subscribers( + hass, mqtt_client_mock, mqtt_mock +): + """Test not calling unsubscribe() when other subscribers are active.""" + # Fake that the client is connected + mqtt_mock.connected = True - mqtt.subscribe(self.hass, "test/state", None) - self.hass.block_till_done() - assert self.hass.data["mqtt"]._mqttc.subscribe.call_count == 1 + unsub = await mqtt.async_subscribe(hass, "test/state", None) + await mqtt.async_subscribe(hass, "test/state", None) + await hass.async_block_till_done() + assert mqtt_client_mock.subscribe.called - self.hass.data["mqtt"]._mqtt_on_disconnect(None, None, 0) - self.hass.data["mqtt"]._mqtt_on_connect(None, None, None, 0) - self.hass.block_till_done() - assert self.hass.data["mqtt"]._mqttc.subscribe.call_count == 2 + unsub() + await hass.async_block_till_done() + assert not mqtt_client_mock.unsubscribe.called - def test_restore_all_active_subscriptions_on_reconnect(self): - """Test active subscriptions are restored correctly on reconnect.""" - # Fake that the client is connected - self.hass.data["mqtt"].connected = True - - self.hass.data["mqtt"]._mqttc.subscribe.side_effect = ( - (0, 1), - (0, 2), - (0, 3), - (0, 4), - ) - - unsub = mqtt.subscribe(self.hass, "test/state", None, qos=2) - mqtt.subscribe(self.hass, "test/state", None) - mqtt.subscribe(self.hass, "test/state", None, qos=1) - self.hass.block_till_done() - - expected = [ - call("test/state", 2), - call("test/state", 0), - call("test/state", 1), - ] - assert self.hass.data["mqtt"]._mqttc.subscribe.mock_calls == expected - - unsub() - self.hass.block_till_done() - assert self.hass.data["mqtt"]._mqttc.unsubscribe.call_count == 0 - - self.hass.data["mqtt"]._mqtt_on_disconnect(None, None, 0) - self.hass.data["mqtt"]._mqtt_on_connect(None, None, None, 0) - self.hass.block_till_done() - - expected.append(call("test/state", 1)) - assert self.hass.data["mqtt"]._mqttc.subscribe.mock_calls == expected - - -async def test_setup_embedded_starts_with_no_config(hass): - """Test setting up embedded server with no config.""" + +async def test_restore_subscriptions_on_reconnect(hass, mqtt_client_mock, mqtt_mock): + """Test subscriptions are restored on reconnect.""" + # Fake that the client is connected + mqtt_mock.connected = True + + await mqtt.async_subscribe(hass, "test/state", None) + await hass.async_block_till_done() + assert mqtt_client_mock.subscribe.call_count == 1 + + mqtt_mock._mqtt_on_disconnect(None, None, 0) + mqtt_mock._mqtt_on_connect(None, None, None, 0) + await hass.async_block_till_done() + assert mqtt_client_mock.subscribe.call_count == 2 + + +async def test_restore_all_active_subscriptions_on_reconnect( + hass, mqtt_client_mock, mqtt_mock +): + """Test active subscriptions are restored correctly on reconnect.""" + # Fake that the client is connected + mqtt_mock.connected = True + + mqtt_client_mock.subscribe.side_effect = ( + (0, 1), + (0, 2), + (0, 3), + (0, 4), + ) + + unsub = await mqtt.async_subscribe(hass, "test/state", None, qos=2) + await mqtt.async_subscribe(hass, "test/state", None) + await mqtt.async_subscribe(hass, "test/state", None, qos=1) + await hass.async_block_till_done() + + expected = [ + call("test/state", 2), + call("test/state", 0), + call("test/state", 1), + ] + assert mqtt_client_mock.subscribe.mock_calls == expected + + unsub() + await hass.async_block_till_done() + assert mqtt_client_mock.unsubscribe.call_count == 0 + + mqtt_mock._mqtt_on_disconnect(None, None, 0) + mqtt_mock._mqtt_on_connect(None, None, None, 0) + await hass.async_block_till_done() + + expected.append(call("test/state", 1)) + assert mqtt_client_mock.subscribe.mock_calls == expected + + +@pytest.fixture +def mqtt_server_start_mock(hass): + """Mock embedded server start.""" client_config = ("localhost", 1883, "user", "pass", None, "3.1.1") with patch( "homeassistant.components.mqtt.server.async_start", return_value=(True, client_config), ) as _start: - await async_mock_mqtt_client(hass, {}) - assert _start.call_count == 1 + yield _start -async def test_setup_embedded_with_embedded(hass): +@pytest.mark.parametrize("mqtt_config", [{}]) +async def test_setup_embedded_starts_with_no_config( + hass, mqtt_server_start_mock, mqtt_mock +): """Test setting up embedded server with no config.""" - client_config = ("localhost", 1883, "user", "pass", None, "3.1.1") + assert mqtt_server_start_mock.call_count == 1 - with patch( - "homeassistant.components.mqtt.server.async_start", - return_value=(True, client_config), - ) as _start: - await async_mock_mqtt_client(hass, {"embedded": None}) - assert _start.call_count == 1 + +@pytest.mark.parametrize("mqtt_config", [{"embedded": None}]) +async def test_setup_embedded_with_embedded(hass, mqtt_server_start_mock, mqtt_mock): + """Test setting up embedded server with empty embedded config.""" + assert mqtt_server_start_mock.call_count == 1 async def test_setup_logs_error_if_no_connect_broker(hass, caplog): @@ -776,42 +793,40 @@ async def test_setup_with_tls_config_of_v1_under_python36_only_uses_v1(hass, moc assert mock_mqtt.mock_calls[0][2]["tls_version"] == ssl.PROTOCOL_TLSv1 -async def test_birth_message(hass): - """Test sending birth message.""" - mqtt_client = await async_mock_mqtt_client( - hass, +@pytest.mark.parametrize( + "mqtt_config", + [ { mqtt.CONF_BROKER: "mock-broker", mqtt.CONF_BIRTH_MESSAGE: { mqtt.ATTR_TOPIC: "birth", mqtt.ATTR_PAYLOAD: "birth", }, - }, - ) + } + ], +) +async def test_birth_message(hass, mqtt_client_mock, mqtt_mock): + """Test sending birth message.""" calls = [] - mqtt_client.publish.side_effect = lambda *args: calls.append(args) - hass.data["mqtt"]._mqtt_on_connect(None, None, 0, 0) + mqtt_client_mock.publish.side_effect = lambda *args: calls.append(args) + mqtt_mock._mqtt_on_connect(None, None, 0, 0) await hass.async_block_till_done() assert calls[-1] == ("birth", "birth", 0, False) -async def test_mqtt_subscribes_topics_on_connect(hass): +async def test_mqtt_subscribes_topics_on_connect(hass, mqtt_client_mock, mqtt_mock): """Test subscription to topic on connect.""" - mqtt_client = await async_mock_mqtt_client(hass) - - hass.data["mqtt"].subscriptions = [ - mqtt.Subscription("topic/test", None), - mqtt.Subscription("home/sensor", None, 2), - mqtt.Subscription("still/pending", None), - mqtt.Subscription("still/pending", None, 1), - ] + await mqtt.async_subscribe(hass, "topic/test", None) + await mqtt.async_subscribe(hass, "home/sensor", None, 2) + await mqtt.async_subscribe(hass, "still/pending", None) + await mqtt.async_subscribe(hass, "still/pending", None, 1) hass.add_job = MagicMock() - hass.data["mqtt"]._mqtt_on_connect(None, None, 0, 0) + mqtt_mock._mqtt_on_connect(None, None, 0, 0) await hass.async_block_till_done() - assert mqtt_client.disconnect.call_count == 0 + assert mqtt_client_mock.disconnect.call_count == 0 expected = {"topic/test": 0, "home/sensor": 2, "still/pending": 1} calls = {call[1][1]: call[1][2] for call in hass.add_job.mock_calls} @@ -824,9 +839,8 @@ async def test_setup_fails_without_config(hass): @pytest.mark.no_fail_on_log_exception -async def test_message_callback_exception_gets_logged(hass, caplog): +async def test_message_callback_exception_gets_logged(hass, caplog, mqtt_mock): """Test exception raised by message handler.""" - await async_mock_mqtt_component(hass) @callback def bad_handler(*args): @@ -843,10 +857,8 @@ async def test_message_callback_exception_gets_logged(hass, caplog): ) -async def test_mqtt_ws_subscription(hass, hass_ws_client): +async def test_mqtt_ws_subscription(hass, hass_ws_client, mqtt_mock): """Test MQTT websocket subscription.""" - await async_mock_mqtt_component(hass) - client = await hass_ws_client(hass) await client.send_json({"id": 5, "type": "mqtt/subscribe", "topic": "test-topic"}) response = await client.receive_json() @@ -869,10 +881,8 @@ async def test_mqtt_ws_subscription(hass, hass_ws_client): assert response["success"] -async def test_dump_service(hass): +async def test_dump_service(hass, mqtt_mock): """Test that we can dump a topic.""" - await async_mock_mqtt_component(hass) - mopen = mock_open() await hass.services.async_call(