diff --git a/homeassistant/components/shelly/__init__.py b/homeassistant/components/shelly/__init__.py index c85728d3ee6..4109130ab80 100644 --- a/homeassistant/components/shelly/__init__.py +++ b/homeassistant/components/shelly/__init__.py @@ -29,6 +29,7 @@ from homeassistant.helpers.typing import ConfigType from .const import ( AIOSHELLY_DEVICE_TIMEOUT_SEC, + ATTR_BETA, ATTR_CHANNEL, ATTR_CLICK_TYPE, ATTR_DEVICE, @@ -66,6 +67,7 @@ from .utils import ( BLOCK_PLATFORMS: Final = [ "binary_sensor", + "button", "climate", "cover", "light", @@ -73,7 +75,7 @@ BLOCK_PLATFORMS: Final = [ "switch", ] BLOCK_SLEEPING_PLATFORMS: Final = ["binary_sensor", "sensor"] -RPC_PLATFORMS: Final = ["binary_sensor", "light", "sensor", "switch"] +RPC_PLATFORMS: Final = ["binary_sensor", "button", "light", "sensor", "switch"] _LOGGER: Final = logging.getLogger(__name__) COAP_SCHEMA: Final = vol.Schema( @@ -424,6 +426,41 @@ class BlockDeviceWrapper(update_coordinator.DataUpdateCoordinator): self.device_id = entry.id self.device.subscribe_updates(self.async_set_updated_data) + async def async_trigger_ota_update(self, beta: bool = False) -> None: + """Trigger or schedule an ota update.""" + update_data = self.device.status["update"] + _LOGGER.debug("OTA update service - update_data: %s", update_data) + + if not update_data["has_update"] and not beta: + _LOGGER.warning("No OTA update available for device %s", self.name) + return + + if beta and not update_data.get("beta_version"): + _LOGGER.warning( + "No OTA update on beta channel available for device %s", self.name + ) + return + + if update_data["status"] == "updating": + _LOGGER.warning("OTA update already in progress for %s", self.name) + return + + new_version = update_data["new_version"] + if beta: + new_version = update_data["beta_version"] + _LOGGER.info( + "Start OTA update of device %s from '%s' to '%s'", + self.name, + self.device.firmware_version, + new_version, + ) + try: + async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): + result = await self.device.trigger_ota_update(beta=beta) + except (asyncio.TimeoutError, OSError) as err: + _LOGGER.exception("Error while perform ota update: %s", err) + _LOGGER.debug("Result of OTA update call: %s", result) + def shutdown(self) -> None: """Shutdown the wrapper.""" self.device.shutdown() @@ -661,6 +698,42 @@ class RpcDeviceWrapper(update_coordinator.DataUpdateCoordinator): self.device_id = entry.id self.device.subscribe_updates(self.async_set_updated_data) + async def async_trigger_ota_update(self, beta: bool = False) -> None: + """Trigger an ota update.""" + + update_data = self.device.status["sys"]["available_updates"] + _LOGGER.debug("OTA update service - update_data: %s", update_data) + + if not bool(update_data) or (not update_data.get("stable") and not beta): + _LOGGER.warning("No OTA update available for device %s", self.name) + return + + if beta and not update_data.get(ATTR_BETA): + _LOGGER.warning( + "No OTA update on beta channel available for device %s", self.name + ) + return + + new_version = update_data.get("stable", {"version": ""})["version"] + if beta: + new_version = update_data.get(ATTR_BETA, {"version": ""})["version"] + + assert self.device.shelly + _LOGGER.info( + "Start OTA update of device %s from '%s' to '%s'", + self.name, + self.device.firmware_version, + new_version, + ) + result = None + try: + async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): + result = await self.device.trigger_ota_update(beta=beta) + except (asyncio.TimeoutError, OSError) as err: + _LOGGER.exception("Error while perform ota update: %s", err) + + _LOGGER.debug("Result of OTA update call: %s", result) + async def shutdown(self) -> None: """Shutdown the wrapper.""" await self.device.shutdown() diff --git a/homeassistant/components/shelly/button.py b/homeassistant/components/shelly/button.py new file mode 100644 index 00000000000..7890eefd30c --- /dev/null +++ b/homeassistant/components/shelly/button.py @@ -0,0 +1,101 @@ +"""Button for Shelly.""" +from __future__ import annotations + +from typing import cast + +from homeassistant.components.button import ButtonEntity +from homeassistant.config_entries import ConfigEntry +from homeassistant.const import ENTITY_CATEGORY_CONFIG +from homeassistant.core import HomeAssistant +from homeassistant.helpers.device_registry import CONNECTION_NETWORK_MAC +from homeassistant.helpers.entity import DeviceInfo +from homeassistant.helpers.entity_platform import AddEntitiesCallback +from homeassistant.util import slugify + +from . import BlockDeviceWrapper, RpcDeviceWrapper +from .const import BLOCK, DATA_CONFIG_ENTRY, DOMAIN, RPC +from .utils import get_block_device_name, get_device_entry_gen, get_rpc_device_name + + +async def async_setup_entry( + hass: HomeAssistant, + config_entry: ConfigEntry, + async_add_entities: AddEntitiesCallback, +) -> None: + """Set buttons for device.""" + wrapper: RpcDeviceWrapper | BlockDeviceWrapper | None = None + if get_device_entry_gen(config_entry) == 2: + if rpc_wrapper := hass.data[DOMAIN][DATA_CONFIG_ENTRY][ + config_entry.entry_id + ].get(RPC): + wrapper = cast(RpcDeviceWrapper, rpc_wrapper) + else: + if block_wrapper := hass.data[DOMAIN][DATA_CONFIG_ENTRY][ + config_entry.entry_id + ].get(BLOCK): + wrapper = cast(BlockDeviceWrapper, block_wrapper) + + if wrapper is not None: + async_add_entities( + [ + ShellyOtaUpdateStableButton(wrapper, config_entry), + ShellyOtaUpdateBetaButton(wrapper, config_entry), + ] + ) + + +class ShellyOtaUpdateBaseButton(ButtonEntity): + """Defines a Shelly OTA update base button.""" + + _attr_entity_category = ENTITY_CATEGORY_CONFIG + + def __init__( + self, + wrapper: RpcDeviceWrapper | BlockDeviceWrapper, + entry: ConfigEntry, + name: str, + beta_channel: bool, + icon: str, + ) -> None: + """Initialize Shelly OTA update button.""" + self._attr_device_info = DeviceInfo( + connections={(CONNECTION_NETWORK_MAC, wrapper.mac)} + ) + + if isinstance(wrapper, RpcDeviceWrapper): + device_name = get_rpc_device_name(wrapper.device) + else: + device_name = get_block_device_name(wrapper.device) + + self._attr_name = f"{device_name} {name}" + self._attr_unique_id = slugify(self._attr_name) + self._attr_icon = icon + + self.beta_channel = beta_channel + self.entry = entry + self.wrapper = wrapper + + async def async_press(self) -> None: + """Triggers the OTA update service.""" + await self.wrapper.async_trigger_ota_update(beta=self.beta_channel) + + +class ShellyOtaUpdateStableButton(ShellyOtaUpdateBaseButton): + """Defines a Shelly OTA update stable channel button.""" + + def __init__( + self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry + ) -> None: + """Initialize Shelly OTA update button.""" + super().__init__(wrapper, entry, "OTA Update", False, "mdi:package-up") + + +class ShellyOtaUpdateBetaButton(ShellyOtaUpdateBaseButton): + """Defines a Shelly OTA update beta channel button.""" + + def __init__( + self, wrapper: RpcDeviceWrapper | BlockDeviceWrapper, entry: ConfigEntry + ) -> None: + """Initialize Shelly OTA update button.""" + super().__init__(wrapper, entry, "OTA Update Beta", True, "mdi:flask-outline") + self._attr_entity_registry_enabled_default = False diff --git a/homeassistant/components/shelly/const.py b/homeassistant/components/shelly/const.py index fd06e88cdf1..8ef3ed5f1ac 100644 --- a/homeassistant/components/shelly/const.py +++ b/homeassistant/components/shelly/const.py @@ -90,6 +90,8 @@ ATTR_CHANNEL: Final = "channel" ATTR_DEVICE: Final = "device" ATTR_GENERATION: Final = "generation" CONF_SUBTYPE: Final = "subtype" +ATTR_BETA: Final = "beta" +CONF_OTA_BETA_CHANNEL: Final = "ota_beta_channel" BASIC_INPUTS_EVENTS_TYPES: Final = {"single", "long"} diff --git a/homeassistant/components/shelly/switch.py b/homeassistant/components/shelly/switch.py index 0291258b511..9114587a910 100644 --- a/homeassistant/components/shelly/switch.py +++ b/homeassistant/components/shelly/switch.py @@ -74,6 +74,7 @@ async def async_setup_rpc_entry( ) -> None: """Set up entities for RPC device.""" wrapper = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id][RPC] + switch_key_ids = get_rpc_key_ids(wrapper.device.status, "switch") switch_ids = [] diff --git a/tests/components/shelly/conftest.py b/tests/components/shelly/conftest.py index a0d4a27bbc4..16dc8fdae9e 100644 --- a/tests/components/shelly/conftest.py +++ b/tests/components/shelly/conftest.py @@ -71,8 +71,25 @@ MOCK_SHELLY = { "num_outputs": 2, } -MOCK_STATUS = { +MOCK_STATUS_COAP = { + "update": { + "status": "pending", + "has_update": True, + "beta_version": "some_beta_version", + "new_version": "some_new_version", + "old_version": "some_old_version", + }, +} + + +MOCK_STATUS_RPC = { "switch:0": {"output": True}, + "sys": { + "available_updates": { + "beta": {"version": "some_beta_version"}, + "stable": {"version": "some_beta_version"}, + } + }, } @@ -117,8 +134,10 @@ async def coap_wrapper(hass): blocks=MOCK_BLOCKS, settings=MOCK_SETTINGS, shelly=MOCK_SHELLY, + status=MOCK_STATUS_COAP, firmware_version="some fw string", update=AsyncMock(), + trigger_ota_update=AsyncMock(), initialized=True, ) @@ -150,9 +169,10 @@ async def rpc_wrapper(hass): config=MOCK_CONFIG, event={}, shelly=MOCK_SHELLY, - status=MOCK_STATUS, + status=MOCK_STATUS_RPC, firmware_version="some fw string", update=AsyncMock(), + trigger_ota_update=AsyncMock(), initialized=True, shutdown=AsyncMock(), ) diff --git a/tests/components/shelly/test_button.py b/tests/components/shelly/test_button.py new file mode 100644 index 00000000000..5ceed08b9d9 --- /dev/null +++ b/tests/components/shelly/test_button.py @@ -0,0 +1,70 @@ +"""Tests for Shelly button platform.""" +from homeassistant.components.button import DOMAIN as BUTTON_DOMAIN +from homeassistant.components.button.const import SERVICE_PRESS +from homeassistant.const import ATTR_ENTITY_ID, STATE_UNKNOWN +from homeassistant.core import HomeAssistant +from homeassistant.helpers.entity_registry import async_get + + +async def test_block_button(hass: HomeAssistant, coap_wrapper): + """Test block device OTA button.""" + assert coap_wrapper + + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(coap_wrapper.entry, BUTTON_DOMAIN) + ) + await hass.async_block_till_done() + + # stable channel button + state = hass.states.get("button.test_name_ota_update") + assert state + assert state.state == STATE_UNKNOWN + + await hass.services.async_call( + BUTTON_DOMAIN, + SERVICE_PRESS, + {ATTR_ENTITY_ID: "button.test_name_ota_update"}, + blocking=True, + ) + await hass.async_block_till_done() + coap_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False) + + # beta channel button + entity_registry = async_get(hass) + entry = entity_registry.async_get("button.test_name_ota_update_beta") + state = hass.states.get("button.test_name_ota_update_beta") + + assert entry + assert state is None + + +async def test_rpc_button(hass: HomeAssistant, rpc_wrapper): + """Test rpc device OTA button.""" + assert rpc_wrapper + + hass.async_create_task( + hass.config_entries.async_forward_entry_setup(rpc_wrapper.entry, BUTTON_DOMAIN) + ) + await hass.async_block_till_done() + + # stable channel button + state = hass.states.get("button.test_name_ota_update") + assert state + assert state.state == STATE_UNKNOWN + + await hass.services.async_call( + BUTTON_DOMAIN, + SERVICE_PRESS, + {ATTR_ENTITY_ID: "button.test_name_ota_update"}, + blocking=True, + ) + await hass.async_block_till_done() + rpc_wrapper.device.trigger_ota_update.assert_called_once_with(beta=False) + + # beta channel button + entity_registry = async_get(hass) + entry = entity_registry.async_get("button.test_name_ota_update_beta") + state = hass.states.get("button.test_name_ota_update_beta") + + assert entry + assert state is None