diff --git a/homeassistant/components/unifiprotect/button.py b/homeassistant/components/unifiprotect/button.py new file mode 100644 index 00000000000..7d3d8bf136e --- /dev/null +++ b/homeassistant/components/unifiprotect/button.py @@ -0,0 +1,57 @@ +"""Support for Ubiquiti's UniFi Protect NVR.""" +from __future__ import annotations + +import logging + +from pyunifiprotect.data.base import ProtectAdoptableDeviceModel + +from homeassistant.components.button import ButtonDeviceClass, ButtonEntity +from homeassistant.config_entries import ConfigEntry +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_platform import AddEntitiesCallback + +from .const import DEVICES_THAT_ADOPT, DOMAIN +from .data import ProtectData +from .entity import ProtectDeviceEntity + +_LOGGER = logging.getLogger(__name__) + + +async def async_setup_entry( + hass: HomeAssistant, + entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Discover devices on a UniFi Protect NVR.""" + data: ProtectData = hass.data[DOMAIN][entry.entry_id] + + async_add_entities( + [ + ProtectButton( + data, + device, + ) + for device in data.get_by_types(DEVICES_THAT_ADOPT) + ] + ) + + +class ProtectButton(ProtectDeviceEntity, ButtonEntity): + """A Ubiquiti UniFi Protect Reboot button.""" + + def __init__( + self, + data: ProtectData, + device: ProtectAdoptableDeviceModel, + ) -> None: + """Initialize an UniFi camera.""" + super().__init__(data, device) + self._attr_name = f"{self.device.name} Reboot Device" + self._attr_entity_registry_enabled_default = False + self._attr_device_class = ButtonDeviceClass.RESTART + + async def async_press(self) -> None: + """Press the button.""" + + _LOGGER.debug("Rebooting %s with id %s", self.device.model, self.device.id) + await self.device.reboot() diff --git a/homeassistant/components/unifiprotect/const.py b/homeassistant/components/unifiprotect/const.py index 1dcaa7ce2a3..a764a20c39c 100644 --- a/homeassistant/components/unifiprotect/const.py +++ b/homeassistant/components/unifiprotect/const.py @@ -41,4 +41,4 @@ DEVICES_FOR_SUBSCRIBE = DEVICES_WITH_ENTITIES | {ModelType.EVENT} MIN_REQUIRED_PROTECT_V = Version("1.20.0") OUTDATED_LOG_MESSAGE = "You are running v%s of UniFi Protect. Minimum required version is v%s. Please upgrade UniFi Protect and then retry" -PLATFORMS = [Platform.CAMERA, Platform.MEDIA_PLAYER] +PLATFORMS = [Platform.BUTTON, Platform.CAMERA, Platform.MEDIA_PLAYER] diff --git a/homeassistant/components/unifiprotect/data.py b/homeassistant/components/unifiprotect/data.py index 9004357b683..6f6e355ec35 100644 --- a/homeassistant/components/unifiprotect/data.py +++ b/homeassistant/components/unifiprotect/data.py @@ -2,13 +2,14 @@ from __future__ import annotations import collections +from collections.abc import Generator, Iterable from datetime import timedelta import logging from typing import Any from pyunifiprotect import NotAuthorized, NvrError, ProtectApiClient -from pyunifiprotect.data import Bootstrap, WSSubscriptionMessage -from pyunifiprotect.data.base import ProtectDeviceModel +from pyunifiprotect.data import Bootstrap, ModelType, WSSubscriptionMessage +from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel from homeassistant.config_entries import ConfigEntry from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback @@ -49,6 +50,18 @@ class ProtectData: """Check if RTSP is disabled.""" return self._entry.options.get(CONF_DISABLE_RTSP, False) + def get_by_types( + self, device_types: Iterable[ModelType] + ) -> Generator[ProtectAdoptableDeviceModel, None, None]: + """Get all devices matching types.""" + + for device_type in device_types: + attr = f"{device_type.value}s" + devices: dict[str, ProtectAdoptableDeviceModel] = getattr( + self.api.bootstrap, attr + ) + yield from devices.values() + async def async_setup(self) -> None: """Subscribe and do the refresh.""" self._unsub_websocket = self.api.subscribe_websocket( diff --git a/tests/components/unifiprotect/test_button.py b/tests/components/unifiprotect/test_button.py new file mode 100644 index 00000000000..931765dd0bf --- /dev/null +++ b/tests/components/unifiprotect/test_button.py @@ -0,0 +1,73 @@ +"""Test the UniFi Protect button platform.""" +# pylint: disable=protected-access +from __future__ import annotations + +from unittest.mock import AsyncMock, patch + +import pytest +from pyunifiprotect.data import Camera + +from homeassistant.components.unifiprotect.const import DEFAULT_ATTRIBUTION +from homeassistant.const import ATTR_ATTRIBUTION, ATTR_ENTITY_ID, Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er + +from .conftest import MockEntityFixture, enable_entity + + +@pytest.fixture(name="camera") +async def camera_fixture( + hass: HomeAssistant, mock_entry: MockEntityFixture, mock_camera: Camera +): + """Fixture for a single camera with only the button platform active, no extra setup.""" + + camera_obj = mock_camera.copy(deep=True) + camera_obj._api = mock_entry.api + camera_obj.channels[0]._api = mock_entry.api + camera_obj.channels[1]._api = mock_entry.api + camera_obj.channels[2]._api = mock_entry.api + camera_obj.name = "Test Camera" + + mock_entry.api.bootstrap.cameras = { + camera_obj.id: camera_obj, + } + + with patch("homeassistant.components.unifiprotect.PLATFORMS", [Platform.BUTTON]): + await hass.config_entries.async_setup(mock_entry.entry.entry_id) + await hass.async_block_till_done() + + entity_registry = er.async_get(hass) + + assert len(hass.states.async_all()) == 0 + assert len(entity_registry.entities) == 1 + + yield (camera_obj, "button.test_camera_reboot_device") + + +async def test_button( + hass: HomeAssistant, + mock_entry: MockEntityFixture, + camera: tuple[Camera, str], +): + """Test button entity.""" + + mock_entry.api.reboot_device = AsyncMock() + + unique_id = f"{camera[0].id}" + entity_id = camera[1] + + entity_registry = er.async_get(hass) + entity = entity_registry.async_get(entity_id) + assert entity + assert entity.disabled + assert entity.unique_id == unique_id + + await enable_entity(hass, mock_entry.entry.entry_id, entity_id) + state = hass.states.get(entity_id) + assert state + assert state.attributes[ATTR_ATTRIBUTION] == DEFAULT_ATTRIBUTION + + await hass.services.async_call( + "button", "press", {ATTR_ENTITY_ID: entity_id}, blocking=True + ) + mock_entry.api.reboot_device.assert_called_once()