diff --git a/homeassistant/components/smlight/button.py b/homeassistant/components/smlight/button.py index de19c57d1b1..d82034b87fb 100644 --- a/homeassistant/components/smlight/button.py +++ b/homeassistant/components/smlight/button.py @@ -5,20 +5,22 @@ from __future__ import annotations from collections.abc import Awaitable, Callable from dataclasses import dataclass import logging -from typing import Final from pysmlight.web import CmdWrapper from homeassistant.components.button import ( + DOMAIN as BUTTON_DOMAIN, ButtonDeviceClass, ButtonEntity, ButtonEntityDescription, ) from homeassistant.config_entries import ConfigEntry from homeassistant.const import EntityCategory -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback +from .const import DOMAIN from .coordinator import SmDataUpdateCoordinator from .entity import SmEntity @@ -32,7 +34,7 @@ class SmButtonDescription(ButtonEntityDescription): press_fn: Callable[[CmdWrapper], Awaitable[None]] -BUTTONS: Final = [ +BUTTONS: list[SmButtonDescription] = [ SmButtonDescription( key="core_restart", translation_key="core_restart", @@ -53,6 +55,13 @@ BUTTONS: Final = [ ), ] +ROUTER = SmButtonDescription( + key="reconnect_zigbee_router", + translation_key="reconnect_zigbee_router", + entity_registry_enabled_default=False, + press_fn=lambda cmd: cmd.zb_router(), +) + async def async_setup_entry( hass: HomeAssistant, @@ -63,6 +72,24 @@ async def async_setup_entry( coordinator = entry.runtime_data.data async_add_entities(SmButton(coordinator, button) for button in BUTTONS) + entity_created = False + + @callback + def _check_router(startup: bool = False) -> None: + nonlocal entity_created + + if coordinator.data.info.zb_type == 1 and not entity_created: + async_add_entities([SmButton(coordinator, ROUTER)]) + entity_created = True + elif coordinator.data.info.zb_type != 1 and (startup or entity_created): + entity_registry = er.async_get(hass) + if entity_id := entity_registry.async_get_entity_id( + BUTTON_DOMAIN, DOMAIN, f"{coordinator.unique_id}-{ROUTER.key}" + ): + entity_registry.async_remove(entity_id) + + coordinator.async_add_listener(_check_router) + _check_router(startup=True) class SmButton(SmEntity, ButtonEntity): diff --git a/homeassistant/components/smlight/strings.json b/homeassistant/components/smlight/strings.json index 97797feae2a..1e6a533beef 100644 --- a/homeassistant/components/smlight/strings.json +++ b/homeassistant/components/smlight/strings.json @@ -108,6 +108,9 @@ }, "zigbee_flash_mode": { "name": "Zigbee flash mode" + }, + "reconnect_zigbee_router": { + "name": "Reconnect zigbee router" } }, "switch": { diff --git a/tests/components/smlight/test_button.py b/tests/components/smlight/test_button.py index 487351acdea..3721ee815e6 100644 --- a/tests/components/smlight/test_button.py +++ b/tests/components/smlight/test_button.py @@ -2,16 +2,19 @@ from unittest.mock import MagicMock +from freezegun.api import FrozenDateTimeFactory +from pysmlight import Info import pytest from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN, SERVICE_PRESS +from homeassistant.components.smlight.const import SCAN_INTERVAL from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN, Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import entity_registry as er from .conftest import setup_integration -from tests.common import MockConfigEntry +from tests.common import MockConfigEntry, async_fire_time_changed @pytest.fixture @@ -20,12 +23,16 @@ def platforms() -> Platform | list[Platform]: return [Platform.BUTTON] +MOCK_ROUTER = Info(MAC="AA:BB:CC:DD:EE:FF", zb_type=1) + + @pytest.mark.parametrize( ("entity_id", "method"), [ ("core_restart", "reboot"), ("zigbee_flash_mode", "zb_bootloader"), ("zigbee_restart", "zb_restart"), + ("reconnect_zigbee_router", "zb_router"), ], ) @pytest.mark.usefixtures("entity_registry_enabled_by_default") @@ -38,6 +45,7 @@ async def test_buttons( mock_smlight_client: MagicMock, ) -> None: """Test creation of button entities.""" + mock_smlight_client.get_info.return_value = MOCK_ROUTER await setup_integration(hass, mock_config_entry) state = hass.states.get(f"button.mock_title_{entity_id}") @@ -61,17 +69,49 @@ async def test_buttons( mock_method.assert_called_with() -@pytest.mark.usefixtures("mock_smlight_client") -async def test_disabled_by_default_button( +@pytest.mark.parametrize("entity_id", ["zigbee_flash_mode", "reconnect_zigbee_router"]) +async def test_disabled_by_default_buttons( hass: HomeAssistant, + entity_id: str, entity_registry: er.EntityRegistry, mock_config_entry: MockConfigEntry, + mock_smlight_client: MagicMock, ) -> None: - """Test the disabled by default flash mode button.""" + """Test the disabled by default buttons.""" + mock_smlight_client.get_info.return_value = MOCK_ROUTER await setup_integration(hass, mock_config_entry) - assert not hass.states.get("button.mock_title_zigbee_flash_mode") + assert not hass.states.get(f"button.mock_{entity_id}") - assert (entry := entity_registry.async_get("button.mock_title_zigbee_flash_mode")) + assert (entry := entity_registry.async_get(f"button.mock_title_{entity_id}")) assert entry.disabled assert entry.disabled_by is er.RegistryEntryDisabler.INTEGRATION + + +async def test_remove_router_reconnect( + hass: HomeAssistant, + entity_registry: er.EntityRegistry, + freezer: FrozenDateTimeFactory, + mock_config_entry: MockConfigEntry, + mock_smlight_client: MagicMock, +) -> None: + """Test removal of orphaned router reconnect button.""" + save_mock = mock_smlight_client.get_info.return_value + mock_smlight_client.get_info.return_value = MOCK_ROUTER + mock_config_entry = await setup_integration(hass, mock_config_entry) + + entities = er.async_entries_for_config_entry( + entity_registry, mock_config_entry.entry_id + ) + assert len(entities) == 4 + assert entities[3].unique_id == "aa:bb:cc:dd:ee:ff-reconnect_zigbee_router" + + mock_smlight_client.get_info.return_value = save_mock + + freezer.tick(SCAN_INTERVAL) + async_fire_time_changed(hass) + + await hass.async_block_till_done() + + entity = entity_registry.async_get("button.mock_title_reconnect_zigbee_router") + assert entity is None