diff --git a/homeassistant/components/unifi/services.py b/homeassistant/components/unifi/services.py index 2dce6f829b0..10d297df883 100644 --- a/homeassistant/components/unifi/services.py +++ b/homeassistant/components/unifi/services.py @@ -1,47 +1,89 @@ """UniFi services.""" +import voluptuous as vol + +from homeassistant.const import ATTR_DEVICE_ID from homeassistant.core import callback +from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from .const import DOMAIN as UNIFI_DOMAIN +SERVICE_RECONNECT_CLIENT = "reconnect_client" SERVICE_REMOVE_CLIENTS = "remove_clients" +SERVICE_RECONNECT_CLIENT_SCHEMA = vol.All( + vol.Schema({vol.Required(ATTR_DEVICE_ID): str}) +) + +SUPPORTED_SERVICES = (SERVICE_RECONNECT_CLIENT, SERVICE_REMOVE_CLIENTS) + +SERVICE_TO_SCHEMA = { + SERVICE_RECONNECT_CLIENT: SERVICE_RECONNECT_CLIENT_SCHEMA, +} + @callback def async_setup_services(hass) -> None: """Set up services for UniFi integration.""" + services = { + SERVICE_RECONNECT_CLIENT: async_reconnect_client, + SERVICE_REMOVE_CLIENTS: async_remove_clients, + } + async def async_call_unifi_service(service_call) -> None: """Call correct UniFi service.""" - service = service_call.service - service_data = service_call.data + await services[service_call.service](hass, service_call.data) - controllers = hass.data[UNIFI_DOMAIN].values() - - if service == SERVICE_REMOVE_CLIENTS: - await async_remove_clients(controllers, service_data) - - hass.services.async_register( - UNIFI_DOMAIN, - SERVICE_REMOVE_CLIENTS, - async_call_unifi_service, - ) + for service in SUPPORTED_SERVICES: + hass.services.async_register( + UNIFI_DOMAIN, + service, + async_call_unifi_service, + schema=SERVICE_TO_SCHEMA.get(service), + ) @callback def async_unload_services(hass) -> None: """Unload UniFi services.""" - hass.services.async_remove(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) + for service in SUPPORTED_SERVICES: + hass.services.async_remove(UNIFI_DOMAIN, service) -async def async_remove_clients(controllers, data) -> None: +async def async_reconnect_client(hass, data) -> None: + """Try to get wireless client to reconnect to Wi-Fi.""" + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get(data[ATTR_DEVICE_ID]) + + mac = "" + for connection in device_entry.connections: + if connection[0] == CONNECTION_NETWORK_MAC: + mac = connection[1] + break + + if mac == "": + return + + for controller in hass.data[UNIFI_DOMAIN].values(): + if ( + not controller.available + or (client := controller.api.clients[mac]) is None + or client.is_wired + ): + continue + + await controller.api.clients.async_reconnect(mac) + + +async def async_remove_clients(hass, data) -> None: """Remove select clients from controller. Validates based on: - Total time between first seen and last seen is less than 15 minutes. - Neither IP, hostname nor name is configured. """ - for controller in controllers: + for controller in hass.data[UNIFI_DOMAIN].values(): if not controller.available: continue diff --git a/homeassistant/components/unifi/services.yaml b/homeassistant/components/unifi/services.yaml index 435661afd4a..7f06adc88a2 100644 --- a/homeassistant/components/unifi/services.yaml +++ b/homeassistant/components/unifi/services.yaml @@ -1,3 +1,15 @@ +reconnect_client: + name: Reconnect wireless client + description: Try to get wireless client to reconnect to UniFi network + fields: + device_id: + name: Device + description: Try reconnect client to wireless network + required: true + selector: + device: + integration: unifi + remove_clients: name: Remove clients from the UniFi Controller description: Clean up clients that has only been associated with the controller for a short period of time. diff --git a/tests/components/unifi/test_services.py b/tests/components/unifi/test_services.py index d9989e8733a..8fe41d7a856 100644 --- a/tests/components/unifi/test_services.py +++ b/tests/components/unifi/test_services.py @@ -3,7 +3,13 @@ from unittest.mock import patch from homeassistant.components.unifi.const import DOMAIN as UNIFI_DOMAIN -from homeassistant.components.unifi.services import SERVICE_REMOVE_CLIENTS +from homeassistant.components.unifi.services import ( + SERVICE_RECONNECT_CLIENT, + SERVICE_REMOVE_CLIENTS, + SUPPORTED_SERVICES, +) +from homeassistant.const import ATTR_DEVICE_ID +from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC from .test_controller import setup_unifi_integration @@ -11,10 +17,12 @@ from .test_controller import setup_unifi_integration async def test_service_setup_and_unload(hass, aioclient_mock): """Verify service setup works.""" config_entry = await setup_unifi_integration(hass, aioclient_mock) - assert hass.services.has_service(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) + for service in SUPPORTED_SERVICES: + assert hass.services.has_service(UNIFI_DOMAIN, service) assert await hass.config_entries.async_unload(config_entry.entry_id) - assert not hass.services.has_service(UNIFI_DOMAIN, SERVICE_REMOVE_CLIENTS) + for service in SUPPORTED_SERVICES: + assert not hass.services.has_service(UNIFI_DOMAIN, service) @patch("homeassistant.core.ServiceRegistry.async_remove") @@ -33,7 +41,157 @@ async def test_service_setup_and_unload_not_called_if_multiple_integrations_dete assert await hass.config_entries.async_unload(config_entry_2.entry_id) remove_service_mock.assert_not_called() assert await hass.config_entries.async_unload(config_entry.entry_id) - remove_service_mock.assert_called_once() + assert remove_service_mock.call_count == 2 + + +async def test_reconnect_client(hass, aioclient_mock): + """Verify call to reconnect client is performed as expected.""" + clients = [ + { + "is_wired": False, + "mac": "00:00:00:00:00:01", + } + ] + config_entry = await setup_unifi_integration( + hass, aioclient_mock, clients_response=clients + ) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/stamgr", + ) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(CONNECTION_NETWORK_MAC, clients[0]["mac"])}, + ) + + await hass.services.async_call( + UNIFI_DOMAIN, + SERVICE_RECONNECT_CLIENT, + service_data={ATTR_DEVICE_ID: device_entry.id}, + blocking=True, + ) + assert aioclient_mock.call_count == 1 + + +async def test_reconnect_device_without_mac(hass, aioclient_mock): + """Verify no call is made if device does not have a known mac.""" + config_entry = await setup_unifi_integration(hass, aioclient_mock) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/stamgr", + ) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={("other connection", "not mac")}, + ) + + await hass.services.async_call( + UNIFI_DOMAIN, + SERVICE_RECONNECT_CLIENT, + service_data={ATTR_DEVICE_ID: device_entry.id}, + blocking=True, + ) + assert aioclient_mock.call_count == 0 + + +async def test_reconnect_client_controller_unavailable(hass, aioclient_mock): + """Verify no call is made if controller is unavailable.""" + clients = [ + { + "is_wired": False, + "mac": "00:00:00:00:00:01", + } + ] + config_entry = await setup_unifi_integration( + hass, aioclient_mock, clients_response=clients + ) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + controller.available = False + + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/stamgr", + ) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(CONNECTION_NETWORK_MAC, clients[0]["mac"])}, + ) + + await hass.services.async_call( + UNIFI_DOMAIN, + SERVICE_RECONNECT_CLIENT, + service_data={ATTR_DEVICE_ID: device_entry.id}, + blocking=True, + ) + assert aioclient_mock.call_count == 0 + + +async def test_reconnect_client_unknown_mac(hass, aioclient_mock): + """Verify no call is made if trying to reconnect a mac unknown to controller.""" + config_entry = await setup_unifi_integration(hass, aioclient_mock) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/stamgr", + ) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(CONNECTION_NETWORK_MAC, "mac unknown to controller")}, + ) + + await hass.services.async_call( + UNIFI_DOMAIN, + SERVICE_RECONNECT_CLIENT, + service_data={ATTR_DEVICE_ID: device_entry.id}, + blocking=True, + ) + assert aioclient_mock.call_count == 0 + + +async def test_reconnect_wired_client(hass, aioclient_mock): + """Verify no call is made if client is wired.""" + clients = [ + { + "is_wired": True, + "mac": "00:00:00:00:00:01", + } + ] + config_entry = await setup_unifi_integration( + hass, aioclient_mock, clients_response=clients + ) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/stamgr", + ) + + device_registry = await hass.helpers.device_registry.async_get_registry() + device_entry = device_registry.async_get_or_create( + config_entry_id=config_entry.entry_id, + connections={(CONNECTION_NETWORK_MAC, clients[0]["mac"])}, + ) + + await hass.services.async_call( + UNIFI_DOMAIN, + SERVICE_RECONNECT_CLIENT, + service_data={ATTR_DEVICE_ID: device_entry.id}, + blocking=True, + ) + assert aioclient_mock.call_count == 0 async def test_remove_clients(hass, aioclient_mock):