diff --git a/tests/components/zha/common.py b/tests/components/zha/common.py index 11237f6cd73..82799e8dd9d 100644 --- a/tests/components/zha/common.py +++ b/tests/components/zha/common.py @@ -34,18 +34,20 @@ class FakeEndpoint: self.device_type = None self.request = AsyncMock(return_value=[0]) - def add_input_cluster(self, cluster_id): + def add_input_cluster(self, cluster_id, _patch_cluster=True): """Add an input cluster.""" cluster = zigpy.zcl.Cluster.from_id(self, cluster_id, is_server=True) - patch_cluster(cluster) + if _patch_cluster: + patch_cluster(cluster) self.in_clusters[cluster_id] = cluster if hasattr(cluster, "ep_attribute"): setattr(self, cluster.ep_attribute, cluster) - def add_output_cluster(self, cluster_id): + def add_output_cluster(self, cluster_id, _patch_cluster=True): """Add an output cluster.""" cluster = zigpy.zcl.Cluster.from_id(self, cluster_id, is_server=False) - patch_cluster(cluster) + if _patch_cluster: + patch_cluster(cluster) self.out_clusters[cluster_id] = cluster reply = AsyncMock(return_value=[0]) diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index a538c1b7f3c..a5aa330a813 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -101,6 +101,7 @@ def zigpy_device_mock(zigpy_app_controller): model="FakeModel", node_descriptor=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", nwk=0xB79C, + patch_cluster=True, ): """Make a fake device using the specified cluster classes.""" device = FakeDevice( @@ -116,10 +117,10 @@ def zigpy_device_mock(zigpy_app_controller): endpoint.profile_id = profile_id for cluster_id in ep.get("in_clusters", []): - endpoint.add_input_cluster(cluster_id) + endpoint.add_input_cluster(cluster_id, _patch_cluster=patch_cluster) for cluster_id in ep.get("out_clusters", []): - endpoint.add_output_cluster(cluster_id) + endpoint.add_output_cluster(cluster_id, _patch_cluster=patch_cluster) return device @@ -187,6 +188,7 @@ def zha_device_mock(hass, zigpy_device_mock): manufacturer="mock manufacturer", model="mock model", node_desc=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", + patch_cluster=True, ): if endpoints is None: endpoints = { @@ -202,9 +204,18 @@ def zha_device_mock(hass, zigpy_device_mock): }, } zigpy_device = zigpy_device_mock( - endpoints, ieee, manufacturer, model, node_desc + endpoints, ieee, manufacturer, model, node_desc, patch_cluster=patch_cluster ) zha_device = zha_core_device.ZHADevice(hass, zigpy_device, MagicMock()) return zha_device return _zha_device + + +@pytest.fixture +def hass_disable_services(hass): + """Mock service register.""" + with patch.object(hass.services, "async_register"), patch.object( + hass.services, "has_service", return_value=True + ): + yield hass diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index 4b95040dd08..9fd01f1de8d 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -44,8 +44,11 @@ def channels_mock(zha_device_mock): manufacturer="mock manufacturer", model="mock model", node_desc=b"\x02@\x807\x10\x7fd\x00\x00*d\x00\x00", + patch_cluster=False, ): - zha_dev = zha_device_mock(endpoints, ieee, manufacturer, model, node_desc) + zha_dev = zha_device_mock( + endpoints, ieee, manufacturer, model, node_desc, patch_cluster=patch_cluster + ) channels = zha_channels.Channels.new(zha_dev) return channels @@ -58,12 +61,11 @@ def channels_mock(zha_device_mock): ) @pytest.mark.parametrize("device", DEVICES) async def test_devices( - device, hass, zigpy_device_mock, monkeypatch, zha_device_joined_restored + device, hass_disable_services, zigpy_device_mock, zha_device_joined_restored, ): """Test device discovery.""" - entity_registry = await homeassistant.helpers.entity_registry.async_get_registry( - hass + hass_disable_services ) zigpy_device = zigpy_device_mock( @@ -72,6 +74,7 @@ async def test_devices( device["manufacturer"], device["model"], node_descriptor=device["node_descriptor"], + patch_cluster=False, ) cluster_identify = _get_first_identify_cluster(zigpy_device) @@ -83,12 +86,12 @@ async def test_devices( try: zha_channels.ChannelPool.async_new_entity = lambda *a, **kw: _dispatch(*a, **kw) zha_dev = await zha_device_joined_restored(zigpy_device) - await hass.async_block_till_done() + await hass_disable_services.async_block_till_done() finally: zha_channels.ChannelPool.async_new_entity = orig_new_entity - entity_ids = hass.states.async_entity_ids() - await hass.async_block_till_done() + entity_ids = hass_disable_services.states.async_entity_ids() + await hass_disable_services.async_block_till_done() zha_entity_ids = { ent for ent in entity_ids if ent.split(".")[0] in zha_const.COMPONENTS } @@ -258,6 +261,7 @@ async def test_discover_endpoint(device_info, channels_mock, hass): manufacturer=device_info["manufacturer"], model=device_info["model"], node_desc=device_info["node_descriptor"], + patch_cluster=False, ) assert device_info["event_channels"] == sorted( @@ -364,7 +368,9 @@ def test_single_input_cluster_device_class_by_cluster_class(): ("switch", "switch.manufacturer_model_77665544_on_off"), ], ) -async def test_device_override(hass, zigpy_device_mock, setup_zha, override, entity_id): +async def test_device_override( + hass_disable_services, zigpy_device_mock, setup_zha, override, entity_id +): """Test device discovery override.""" zigpy_device = zigpy_device_mock( @@ -380,23 +386,26 @@ async def test_device_override(hass, zigpy_device_mock, setup_zha, override, ent "00:11:22:33:44:55:66:77", "manufacturer", "model", + patch_cluster=False, ) if override is not None: override = {"device_config": {"00:11:22:33:44:55:66:77-1": {"type": override}}} await setup_zha(override) - assert hass.states.get(entity_id) is None - zha_gateway = get_zha_gateway(hass) + assert hass_disable_services.states.get(entity_id) is None + zha_gateway = get_zha_gateway(hass_disable_services) await zha_gateway.async_device_initialized(zigpy_device) - await hass.async_block_till_done() - assert hass.states.get(entity_id) is not None + await hass_disable_services.async_block_till_done() + assert hass_disable_services.states.get(entity_id) is not None -async def test_group_probe_cleanup_called(hass, setup_zha, config_entry): +async def test_group_probe_cleanup_called( + hass_disable_services, setup_zha, config_entry +): """Test cleanup happens when zha is unloaded.""" await setup_zha() disc.GROUP_PROBE.cleanup = mock.Mock(wraps=disc.GROUP_PROBE.cleanup) - await config_entry.async_unload(hass) - await hass.async_block_till_done() + await config_entry.async_unload(hass_disable_services) + await hass_disable_services.async_block_till_done() disc.GROUP_PROBE.cleanup.assert_called()