Add UniFi Protect button (#63001)

This commit is contained in:
Christopher Bailey 2021-12-29 14:38:44 -05:00 committed by GitHub
parent b31041698f
commit 699512c36f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 146 additions and 3 deletions

View File

@ -0,0 +1,57 @@
"""Support for Ubiquiti's UniFi Protect NVR."""
from __future__ import annotations
import logging
from pyunifiprotect.data.base import ProtectAdoptableDeviceModel
from homeassistant.components.button import ButtonDeviceClass, ButtonEntity
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from .const import DEVICES_THAT_ADOPT, DOMAIN
from .data import ProtectData
from .entity import ProtectDeviceEntity
_LOGGER = logging.getLogger(__name__)
async def async_setup_entry(
hass: HomeAssistant,
entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Discover devices on a UniFi Protect NVR."""
data: ProtectData = hass.data[DOMAIN][entry.entry_id]
async_add_entities(
[
ProtectButton(
data,
device,
)
for device in data.get_by_types(DEVICES_THAT_ADOPT)
]
)
class ProtectButton(ProtectDeviceEntity, ButtonEntity):
"""A Ubiquiti UniFi Protect Reboot button."""
def __init__(
self,
data: ProtectData,
device: ProtectAdoptableDeviceModel,
) -> None:
"""Initialize an UniFi camera."""
super().__init__(data, device)
self._attr_name = f"{self.device.name} Reboot Device"
self._attr_entity_registry_enabled_default = False
self._attr_device_class = ButtonDeviceClass.RESTART
async def async_press(self) -> None:
"""Press the button."""
_LOGGER.debug("Rebooting %s with id %s", self.device.model, self.device.id)
await self.device.reboot()

View File

@ -41,4 +41,4 @@ DEVICES_FOR_SUBSCRIBE = DEVICES_WITH_ENTITIES | {ModelType.EVENT}
MIN_REQUIRED_PROTECT_V = Version("1.20.0") MIN_REQUIRED_PROTECT_V = Version("1.20.0")
OUTDATED_LOG_MESSAGE = "You are running v%s of UniFi Protect. Minimum required version is v%s. Please upgrade UniFi Protect and then retry" OUTDATED_LOG_MESSAGE = "You are running v%s of UniFi Protect. Minimum required version is v%s. Please upgrade UniFi Protect and then retry"
PLATFORMS = [Platform.CAMERA, Platform.MEDIA_PLAYER] PLATFORMS = [Platform.BUTTON, Platform.CAMERA, Platform.MEDIA_PLAYER]

View File

@ -2,13 +2,14 @@
from __future__ import annotations from __future__ import annotations
import collections import collections
from collections.abc import Generator, Iterable
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any from typing import Any
from pyunifiprotect import NotAuthorized, NvrError, ProtectApiClient from pyunifiprotect import NotAuthorized, NvrError, ProtectApiClient
from pyunifiprotect.data import Bootstrap, WSSubscriptionMessage from pyunifiprotect.data import Bootstrap, ModelType, WSSubscriptionMessage
from pyunifiprotect.data.base import ProtectDeviceModel from pyunifiprotect.data.base import ProtectAdoptableDeviceModel, ProtectDeviceModel
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback from homeassistant.core import CALLBACK_TYPE, HomeAssistant, callback
@ -49,6 +50,18 @@ class ProtectData:
"""Check if RTSP is disabled.""" """Check if RTSP is disabled."""
return self._entry.options.get(CONF_DISABLE_RTSP, False) return self._entry.options.get(CONF_DISABLE_RTSP, False)
def get_by_types(
self, device_types: Iterable[ModelType]
) -> Generator[ProtectAdoptableDeviceModel, None, None]:
"""Get all devices matching types."""
for device_type in device_types:
attr = f"{device_type.value}s"
devices: dict[str, ProtectAdoptableDeviceModel] = getattr(
self.api.bootstrap, attr
)
yield from devices.values()
async def async_setup(self) -> None: async def async_setup(self) -> None:
"""Subscribe and do the refresh.""" """Subscribe and do the refresh."""
self._unsub_websocket = self.api.subscribe_websocket( self._unsub_websocket = self.api.subscribe_websocket(

View File

@ -0,0 +1,73 @@
"""Test the UniFi Protect button platform."""
# pylint: disable=protected-access
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from pyunifiprotect.data import Camera
from homeassistant.components.unifiprotect.const import DEFAULT_ATTRIBUTION
from homeassistant.const import ATTR_ATTRIBUTION, ATTR_ENTITY_ID, Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from .conftest import MockEntityFixture, enable_entity
@pytest.fixture(name="camera")
async def camera_fixture(
hass: HomeAssistant, mock_entry: MockEntityFixture, mock_camera: Camera
):
"""Fixture for a single camera with only the button platform active, no extra setup."""
camera_obj = mock_camera.copy(deep=True)
camera_obj._api = mock_entry.api
camera_obj.channels[0]._api = mock_entry.api
camera_obj.channels[1]._api = mock_entry.api
camera_obj.channels[2]._api = mock_entry.api
camera_obj.name = "Test Camera"
mock_entry.api.bootstrap.cameras = {
camera_obj.id: camera_obj,
}
with patch("homeassistant.components.unifiprotect.PLATFORMS", [Platform.BUTTON]):
await hass.config_entries.async_setup(mock_entry.entry.entry_id)
await hass.async_block_till_done()
entity_registry = er.async_get(hass)
assert len(hass.states.async_all()) == 0
assert len(entity_registry.entities) == 1
yield (camera_obj, "button.test_camera_reboot_device")
async def test_button(
hass: HomeAssistant,
mock_entry: MockEntityFixture,
camera: tuple[Camera, str],
):
"""Test button entity."""
mock_entry.api.reboot_device = AsyncMock()
unique_id = f"{camera[0].id}"
entity_id = camera[1]
entity_registry = er.async_get(hass)
entity = entity_registry.async_get(entity_id)
assert entity
assert entity.disabled
assert entity.unique_id == unique_id
await enable_entity(hass, mock_entry.entry.entry_id, entity_id)
state = hass.states.get(entity_id)
assert state
assert state.attributes[ATTR_ATTRIBUTION] == DEFAULT_ATTRIBUTION
await hass.services.async_call(
"button", "press", {ATTR_ENTITY_ID: entity_id}, blocking=True
)
mock_entry.api.reboot_device.assert_called_once()