diff --git a/homeassistant/components/unifi/button.py b/homeassistant/components/unifi/button.py new file mode 100644 index 00000000000..6b0660325f0 --- /dev/null +++ b/homeassistant/components/unifi/button.py @@ -0,0 +1,111 @@ +"""Button platform for UniFi Network integration. + +Support for restarting UniFi devices. +""" +from __future__ import annotations + +from collections.abc import Callable, Coroutine +from dataclasses import dataclass +from typing import Any, Generic + +import aiounifi +from aiounifi.interfaces.api_handlers import ItemEvent +from aiounifi.interfaces.devices import Devices +from aiounifi.models.api import ApiItemT +from aiounifi.models.device import Device, DeviceRestartRequest + +from homeassistant.components.button import ( + ButtonDeviceClass, + ButtonEntity, + ButtonEntityDescription, +) +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import EntityCategory +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DOMAIN as UNIFI_DOMAIN +from .controller import UniFiController +from .entity import ( + HandlerT, + UnifiEntity, + UnifiEntityDescription, + async_device_available_fn, + async_device_device_info_fn, +) + + +@callback +async def async_restart_device_control_fn( + api: aiounifi.Controller, obj_id: str +) -> None: + """Restart device.""" + await api.request(DeviceRestartRequest.create(obj_id)) + + +@dataclass +class UnifiButtonEntityDescriptionMixin(Generic[HandlerT, ApiItemT]): + """Validate and load entities from different UniFi handlers.""" + + control_fn: Callable[[aiounifi.Controller, str], Coroutine[Any, Any, None]] + + +@dataclass +class UnifiButtonEntityDescription( + ButtonEntityDescription, + UnifiEntityDescription[HandlerT, ApiItemT], + UnifiButtonEntityDescriptionMixin[HandlerT, ApiItemT], +): + """Class describing UniFi button entity.""" + + +ENTITY_DESCRIPTIONS: tuple[UnifiButtonEntityDescription, ...] = ( + UnifiButtonEntityDescription[Devices, Device]( + key="Device restart", + entity_category=EntityCategory.CONFIG, + has_entity_name=True, + device_class=ButtonDeviceClass.RESTART, + allowed_fn=lambda controller, obj_id: True, + api_handler_fn=lambda api: api.devices, + available_fn=async_device_available_fn, + control_fn=async_restart_device_control_fn, + device_info_fn=async_device_device_info_fn, + event_is_on=None, + event_to_subscribe=None, + name_fn=lambda _: "Restart", + object_fn=lambda api, obj_id: api.devices[obj_id], + should_poll=False, + supported_fn=lambda controller, obj_id: True, + unique_id_fn=lambda controller, obj_id: f"device_restart-{obj_id}", + ), +) + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set up button platform for UniFi Network integration.""" + controller: UniFiController = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + if controller.site_role != "admin": + return + + controller.register_platform_add_entities( + UnifiButtonEntity, ENTITY_DESCRIPTIONS, async_add_entities + ) + + +class UnifiButtonEntity(UnifiEntity[HandlerT, ApiItemT], ButtonEntity): + """Base representation of a UniFi image.""" + + entity_description: UnifiButtonEntityDescription[HandlerT, ApiItemT] + + async def async_press(self) -> None: + """Press the button.""" + await self.entity_description.control_fn(self.controller.api, self._obj_id) + + @callback + def async_update_state(self, event: ItemEvent, obj_id: str) -> None: + """Update entity state.""" diff --git a/homeassistant/components/unifi/const.py b/homeassistant/components/unifi/const.py index e03bd50d483..176511645aa 100644 --- a/homeassistant/components/unifi/const.py +++ b/homeassistant/components/unifi/const.py @@ -8,6 +8,7 @@ LOGGER = logging.getLogger(__package__) DOMAIN = "unifi" PLATFORMS = [ + Platform.BUTTON, Platform.DEVICE_TRACKER, Platform.IMAGE, Platform.SENSOR, diff --git a/tests/components/unifi/test_button.py b/tests/components/unifi/test_button.py new file mode 100644 index 00000000000..89b65b1f981 --- /dev/null +++ b/tests/components/unifi/test_button.py @@ -0,0 +1,92 @@ +"""UniFi Network button platform tests.""" + +from aiounifi.websocket import WebsocketState + +from homeassistant.components.button import ( + DOMAIN as BUTTON_DOMAIN, + ButtonDeviceClass, +) +from homeassistant.components.unifi.const import ( + DOMAIN as UNIFI_DOMAIN, +) +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + STATE_UNAVAILABLE, + EntityCategory, +) +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er + +from .test_controller import ( + setup_unifi_integration, +) + +from tests.test_util.aiohttp import AiohttpClientMocker + + +async def test_restart_device_button( + hass: HomeAssistant, aioclient_mock: AiohttpClientMocker, mock_unifi_websocket +) -> None: + """Test restarting device button.""" + config_entry = await setup_unifi_integration( + hass, + aioclient_mock, + devices_response=[ + { + "board_rev": 3, + "device_id": "mock-id", + "ip": "10.0.0.1", + "last_seen": 1562600145, + "mac": "00:00:00:00:01:01", + "model": "US16P150", + "name": "switch", + "state": 1, + "type": "usw", + "version": "4.0.42.10433", + } + ], + ) + controller = hass.data[UNIFI_DOMAIN][config_entry.entry_id] + + assert len(hass.states.async_entity_ids(BUTTON_DOMAIN)) == 1 + + ent_reg = er.async_get(hass) + ent_reg_entry = ent_reg.async_get("button.switch_restart") + assert ent_reg_entry.unique_id == "device_restart-00:00:00:00:01:01" + assert ent_reg_entry.entity_category is EntityCategory.CONFIG + + # Validate state object + button = hass.states.get("button.switch_restart") + assert button is not None + assert button.attributes.get(ATTR_DEVICE_CLASS) == ButtonDeviceClass.RESTART + + # Send restart device command + aioclient_mock.clear_requests() + aioclient_mock.post( + f"https://{controller.host}:1234/api/s/{controller.site}/cmd/devmgr", + ) + + await hass.services.async_call( + BUTTON_DOMAIN, + "press", + {"entity_id": "button.switch_restart"}, + blocking=True, + ) + assert aioclient_mock.call_count == 1 + assert aioclient_mock.mock_calls[0][2] == { + "cmd": "restart", + "mac": "00:00:00:00:01:01", + "reboot_type": "soft", + } + + # Availability signalling + + # Controller disconnects + mock_unifi_websocket(state=WebsocketState.DISCONNECTED) + await hass.async_block_till_done() + assert hass.states.get("button.switch_restart").state == STATE_UNAVAILABLE + + # Controller reconnects + mock_unifi_websocket(state=WebsocketState.RUNNING) + await hass.async_block_till_done() + assert hass.states.get("button.switch_restart").state != STATE_UNAVAILABLE diff --git a/tests/components/unifi/test_controller.py b/tests/components/unifi/test_controller.py index 5f1b5d33dcd..2d28240a90d 100644 --- a/tests/components/unifi/test_controller.py +++ b/tests/components/unifi/test_controller.py @@ -9,6 +9,7 @@ import aiounifi from aiounifi.websocket import WebsocketState import pytest +from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN from homeassistant.components.device_tracker import DOMAIN as TRACKER_DOMAIN from homeassistant.components.image import DOMAIN as IMAGE_DOMAIN from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN @@ -222,10 +223,11 @@ async def test_controller_setup( entry = controller.config_entry assert len(forward_entry_setup.mock_calls) == len(PLATFORMS) - assert forward_entry_setup.mock_calls[0][1] == (entry, TRACKER_DOMAIN) - assert forward_entry_setup.mock_calls[1][1] == (entry, IMAGE_DOMAIN) - assert forward_entry_setup.mock_calls[2][1] == (entry, SENSOR_DOMAIN) - assert forward_entry_setup.mock_calls[3][1] == (entry, SWITCH_DOMAIN) + assert forward_entry_setup.mock_calls[0][1] == (entry, BUTTON_DOMAIN) + assert forward_entry_setup.mock_calls[1][1] == (entry, TRACKER_DOMAIN) + assert forward_entry_setup.mock_calls[2][1] == (entry, IMAGE_DOMAIN) + assert forward_entry_setup.mock_calls[3][1] == (entry, SENSOR_DOMAIN) + assert forward_entry_setup.mock_calls[4][1] == (entry, SWITCH_DOMAIN) assert controller.host == ENTRY_CONFIG[CONF_HOST] assert controller.site == ENTRY_CONFIG[CONF_SITE_ID]