diff --git a/homeassistant/components/unifi/__init__.py b/homeassistant/components/unifi/__init__.py index 2394dfe92d8..03816c03df7 100644 --- a/homeassistant/components/unifi/__init__.py +++ b/homeassistant/components/unifi/__init__.py @@ -43,8 +43,10 @@ async def async_setup_entry(hass, config_entry): config_entry, unique_id=controller.site_id ) + if not hass.data[UNIFI_DOMAIN]: + async_setup_services(hass) + hass.data[UNIFI_DOMAIN][config_entry.entry_id] = controller - await async_setup_services(hass) config_entry.async_on_unload( hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, controller.shutdown) @@ -72,7 +74,7 @@ async def async_unload_entry(hass, config_entry): controller = hass.data[UNIFI_DOMAIN].pop(config_entry.entry_id) if not hass.data[UNIFI_DOMAIN]: - await async_unload_services(hass) + async_unload_services(hass) return await controller.async_reset() diff --git a/homeassistant/components/unifi/services.py b/homeassistant/components/unifi/services.py index dca95a764c3..2dce6f829b0 100644 --- a/homeassistant/components/unifi/services.py +++ b/homeassistant/components/unifi/services.py @@ -1,18 +1,15 @@ """UniFi services.""" -from .const import DOMAIN as UNIFI_DOMAIN +from homeassistant.core import callback -UNIFI_SERVICES = "unifi_services" +from .const import DOMAIN as UNIFI_DOMAIN SERVICE_REMOVE_CLIENTS = "remove_clients" -async def async_setup_services(hass) -> None: +@callback +def async_setup_services(hass) -> None: """Set up services for UniFi integration.""" - if hass.data.get(UNIFI_SERVICES, False): - return - - hass.data[UNIFI_SERVICES] = True async def async_call_unifi_service(service_call) -> None: """Call correct UniFi service.""" @@ -31,13 +28,9 @@ async def async_setup_services(hass) -> None: ) -async def async_unload_services(hass) -> None: +@callback +def async_unload_services(hass) -> None: """Unload UniFi services.""" - if not hass.data.get(UNIFI_SERVICES): - return - - hass.data[UNIFI_SERVICES] = False - hass.services.async_remove(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index ec666ff27b9..02745dd60fb 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -166,6 +166,7 @@ async def setup_unifi_integration( known_wireless_clients=None, controllers=None, unique_id="1", + config_entry_id=DEFAULT_CONFIG_ENTRY_ID, ): """Create the UniFi controller.""" assert await async_setup_component(hass, UNIFI_DOMAIN, {}) @@ -175,7 +176,7 @@ async def setup_unifi_integration( data=deepcopy(config), options=deepcopy(options), unique_id=unique_id, - entry_id=DEFAULT_CONFIG_ENTRY_ID, + entry_id=config_entry_id, version=1, ) config_entry.add_to_hass(hass) diff --git a/tests/components/unifi/test_services.py b/tests/components/unifi/test_services.py index 388a33a4c64..d9989e8733a 100644 --- a/tests/components/unifi/test_services.py +++ b/tests/components/unifi/test_services.py @@ -1,58 +1,39 @@ """deCONZ service tests.""" -from unittest.mock import Mock, patch +from unittest.mock import patch from homeassistant.components.unifi.const import DOMAIN as UNIFI_DOMAIN -from homeassistant.components.unifi.services import ( - SERVICE_REMOVE_CLIENTS, - UNIFI_SERVICES, - async_setup_services, - async_unload_services, -) +from homeassistant.components.unifi.services import SERVICE_REMOVE_CLIENTS from .test_controller import setup_unifi_integration -async def test_service_setup(hass): +async def test_service_setup_and_unload(hass, aioclient_mock): """Verify service setup works.""" - assert UNIFI_SERVICES not in hass.data - with patch( - "homeassistant.core.ServiceRegistry.async_register", return_value=Mock(True) - ) as async_register: - await async_setup_services(hass) - assert hass.data[UNIFI_SERVICES] is True - assert async_register.call_count == 1 + config_entry = await setup_unifi_integration(hass, aioclient_mock) + assert hass.services.has_service(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) + + assert await hass.config_entries.async_unload(config_entry.entry_id) + assert not hass.services.has_service(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) -async def test_service_setup_already_registered(hass): - """Make sure that services are only registered once.""" - hass.data[UNIFI_SERVICES] = True - with patch( - "homeassistant.core.ServiceRegistry.async_register", return_value=Mock(True) - ) as async_register: - await async_setup_services(hass) - async_register.assert_not_called() +@patch("homeassistant.core.ServiceRegistry.async_remove") +@patch("homeassistant.core.ServiceRegistry.async_register") +async def test_service_setup_and_unload_not_called_if_multiple_integrations_detected( + register_service_mock, remove_service_mock, hass, aioclient_mock +): + """Make sure that services are only setup and removed once.""" + config_entry = await setup_unifi_integration(hass, aioclient_mock) + register_service_mock.reset_mock() + config_entry_2 = await setup_unifi_integration( + hass, aioclient_mock, config_entry_id=2 + ) + register_service_mock.assert_not_called() - -async def test_service_unload(hass): - """Verify service unload works.""" - hass.data[UNIFI_SERVICES] = True - with patch( - "homeassistant.core.ServiceRegistry.async_remove", return_value=Mock(True) - ) as async_remove: - await async_unload_services(hass) - assert hass.data[UNIFI_SERVICES] is False - assert async_remove.call_count == 1 - - -async def test_service_unload_not_registered(hass): - """Make sure that services can only be unloaded once.""" - with patch( - "homeassistant.core.ServiceRegistry.async_remove", return_value=Mock(True) - ) as async_remove: - await async_unload_services(hass) - assert UNIFI_SERVICES not in hass.data - async_remove.assert_not_called() + assert await hass.config_entries.async_unload(config_entry_2.entry_id) + remove_service_mock.assert_not_called() + assert await hass.config_entries.async_unload(config_entry.entry_id) + remove_service_mock.assert_called_once() async def test_remove_clients(hass, aioclient_mock): @@ -103,6 +84,8 @@ async def test_remove_clients(hass, aioclient_mock): "macs": ["00:00:00:00:00:01"], } + assert await hass.config_entries.async_unload(config_entry.entry_id) + async def test_remove_clients_controller_unavailable(hass, aioclient_mock): """Verify no call is made if controller is unavailable."""