diff --git a/homeassistant/components/zha/core/cluster_handlers/__init__.py b/homeassistant/components/zha/core/cluster_handlers/__init__.py index 94cd1f49ca8..1f7485d4922 100644 --- a/homeassistant/components/zha/core/cluster_handlers/__init__.py +++ b/homeassistant/components/zha/core/cluster_handlers/__init__.py @@ -627,8 +627,9 @@ class ClientClusterHandler(ClusterHandler): """ClusterHandler for Zigbee client (output) clusters.""" @callback - def attribute_updated(self, attrid: int, value: Any, _: Any) -> None: + def attribute_updated(self, attrid: int, value: Any, timestamp: Any) -> None: """Handle an attribute updated on this cluster.""" + super().attribute_updated(attrid, value, timestamp) try: attr_name = self._cluster.attributes[attrid].name diff --git a/homeassistant/components/zha/core/cluster_handlers/general.py b/homeassistant/components/zha/core/cluster_handlers/general.py index 14401b260b2..d2927f6d028 100644 --- a/homeassistant/components/zha/core/cluster_handlers/general.py +++ b/homeassistant/components/zha/core/cluster_handlers/general.py @@ -56,7 +56,6 @@ from ..const import ( SIGNAL_MOVE_LEVEL, SIGNAL_SET_LEVEL, SIGNAL_UPDATE_DEVICE, - UNKNOWN as ZHA_UNKNOWN, ) from . import ( AttrReportConfig, @@ -538,14 +537,9 @@ class OtaClusterHandler(ClusterHandler): } @property - def current_file_version(self) -> str: + def current_file_version(self) -> int | None: """Return cached value of current_file_version attribute.""" - current_file_version = self.cluster.get( - Ota.AttributeDefs.current_file_version.name - ) - if current_file_version is not None: - return f"0x{int(current_file_version):08x}" - return ZHA_UNKNOWN + return self.cluster.get(Ota.AttributeDefs.current_file_version.name) @registries.CLIENT_CLUSTER_HANDLER_REGISTRY.register(Ota.cluster_id) @@ -559,36 +553,31 @@ class OtaClientClusterHandler(ClientClusterHandler): } @property - def current_file_version(self) -> str: + def current_file_version(self) -> int | None: """Return cached value of current_file_version attribute.""" - current_file_version = self.cluster.get( - Ota.AttributeDefs.current_file_version.name - ) - if current_file_version is not None: - return f"0x{int(current_file_version):08x}" - return ZHA_UNKNOWN + return self.cluster.get(Ota.AttributeDefs.current_file_version.name) @callback def cluster_command( self, tsn: int, command_id: int, args: list[Any] | None ) -> None: """Handle OTA commands.""" - if command_id in self.cluster.server_commands: - cmd_name = self.cluster.server_commands[command_id].name - else: - cmd_name = command_id + if command_id not in self.cluster.server_commands: + return signal_id = self._endpoint.unique_id.split("-")[0] + cmd_name = self.cluster.server_commands[command_id].name + if cmd_name == Ota.ServerCommandDefs.query_next_image.name: assert args - self.async_send_signal(SIGNAL_UPDATE_DEVICE.format(signal_id), args[3]) - async def async_check_for_update(self): - """Check for firmware availability by issuing an image notify command.""" - await self.cluster.image_notify( - payload_type=(self.cluster.ImageNotifyCommand.PayloadType.QueryJitter), - query_jitter=100, - ) + current_file_version = args[3] + self.cluster.update_attribute( + Ota.AttributeDefs.current_file_version.id, current_file_version + ) + self.async_send_signal( + SIGNAL_UPDATE_DEVICE.format(signal_id), current_file_version + ) @registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(Partition.cluster_id) diff --git a/homeassistant/components/zha/core/discovery.py b/homeassistant/components/zha/core/discovery.py index 06dbfa46a7e..5575d633593 100644 --- a/homeassistant/components/zha/core/discovery.py +++ b/homeassistant/components/zha/core/discovery.py @@ -75,11 +75,12 @@ async def async_add_entities( tuple[str, ZHADevice, list[ClusterHandler]], ] ], + **kwargs, ) -> None: """Add entities helper.""" if not entities: return - to_add = [ent_cls.create_entity(*args) for ent_cls, args in entities] + to_add = [ent_cls.create_entity(*args, **kwargs) for ent_cls, args in entities] entities_to_add = [entity for entity in to_add if entity is not None] _async_add_entities(entities_to_add, update_before_add=False) entities.clear() diff --git a/homeassistant/components/zha/manifest.json b/homeassistant/components/zha/manifest.json index 216947515a1..3a1df4207ac 100644 --- a/homeassistant/components/zha/manifest.json +++ b/homeassistant/components/zha/manifest.json @@ -27,7 +27,7 @@ "pyserial-asyncio==0.6", "zha-quirks==0.0.112", "zigpy-deconz==0.23.1", - "zigpy==0.62.3", + "zigpy==0.63.2", "zigpy-xbee==0.20.1", "zigpy-zigate==0.12.0", "zigpy-znp==0.12.1", diff --git a/homeassistant/components/zha/update.py b/homeassistant/components/zha/update.py index e92424acf47..d45c24253be 100644 --- a/homeassistant/components/zha/update.py +++ b/homeassistant/components/zha/update.py @@ -1,17 +1,16 @@ """Representation of ZHA updates.""" from __future__ import annotations -from dataclasses import dataclass import functools +import logging +import math from typing import TYPE_CHECKING, Any -from zigpy.ota.image import BaseOTAImage -from zigpy.types import uint16_t +from zigpy.ota import OtaImageWithMetadata +from zigpy.zcl.clusters.general import Ota from zigpy.zcl.foundation import Status from homeassistant.components.update import ( - ATTR_INSTALLED_VERSION, - ATTR_LATEST_VERSION, UpdateDeviceClass, UpdateEntity, UpdateEntityFeature, @@ -22,36 +21,29 @@ from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.entity_platform import AddEntitiesCallback -from homeassistant.helpers.restore_state import ExtraStoredData +from homeassistant.helpers.update_coordinator import ( + CoordinatorEntity, + DataUpdateCoordinator, +) from .core import discovery -from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, UNKNOWN -from .core.helpers import get_zha_data +from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED +from .core.helpers import get_zha_data, get_zha_gateway from .core.registries import ZHA_ENTITIES from .entity import ZhaEntity if TYPE_CHECKING: + from zigpy.application import ControllerApplication + from .core.cluster_handlers import ClusterHandler from .core.device import ZHADevice +_LOGGER = logging.getLogger(__name__) + CONFIG_DIAGNOSTIC_MATCH = functools.partial( ZHA_ENTITIES.config_diagnostic_match, Platform.UPDATE ) -# don't let homeassistant check for updates button hammer the zigbee network -PARALLEL_UPDATES = 1 - - -@dataclass -class ZHAFirmwareUpdateExtraStoredData(ExtraStoredData): - """Extra stored data for ZHA firmware update entity.""" - - image_type: uint16_t | None - - def as_dict(self) -> dict[str, Any]: - """Return a dict representation of the extra data.""" - return {"image_type": self.image_type} - async def async_setup_entry( hass: HomeAssistant, @@ -62,18 +54,46 @@ async def async_setup_entry( zha_data = get_zha_data(hass) entities_to_create = zha_data.platforms[Platform.UPDATE] + coordinator = ZHAFirmwareUpdateCoordinator( + hass, get_zha_gateway(hass).application_controller + ) + unsub = async_dispatcher_connect( hass, SIGNAL_ADD_ENTITIES, functools.partial( - discovery.async_add_entities, async_add_entities, entities_to_create + discovery.async_add_entities, + async_add_entities, + entities_to_create, + coordinator=coordinator, ), ) config_entry.async_on_unload(unsub) +class ZHAFirmwareUpdateCoordinator(DataUpdateCoordinator): # pylint: disable=hass-enforce-coordinator-module + """Firmware update coordinator that broadcasts updates network-wide.""" + + def __init__( + self, hass: HomeAssistant, controller_application: ControllerApplication + ) -> None: + """Initialize the coordinator.""" + super().__init__( + hass, + _LOGGER, + name="ZHA firmware update coordinator", + update_method=self.async_update_data, + ) + self.controller_application = controller_application + + async def async_update_data(self) -> None: + """Fetch the latest firmware update data.""" + # Broadcast to all devices + await self.controller_application.ota.broadcast_notify(jitter=100) + + @CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_OTA) -class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity): +class ZHAFirmwareUpdateEntity(ZhaEntity, CoordinatorEntity, UpdateEntity): """Representation of a ZHA firmware update entity.""" _unique_id_suffix = "firmware_update" @@ -90,147 +110,114 @@ class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity): unique_id: str, zha_device: ZHADevice, channels: list[ClusterHandler], + coordinator: ZHAFirmwareUpdateCoordinator, **kwargs: Any, ) -> None: """Initialize the ZHA update entity.""" super().__init__(unique_id, zha_device, channels, **kwargs) + CoordinatorEntity.__init__(self, coordinator) + self._ota_cluster_handler: ClusterHandler = self.cluster_handlers[ CLUSTER_HANDLER_OTA ] - self._attr_installed_version: str = self.determine_installed_version() - self._image_type: uint16_t | None = None - self._latest_version_firmware: BaseOTAImage | None = None - self._result = None + self._attr_installed_version: str | None = self._get_cluster_version() + self._attr_latest_version = self._attr_installed_version + self._latest_firmware: OtaImageWithMetadata | None = None + + def _get_cluster_version(self) -> str | None: + """Synchronize current file version with the cluster.""" + + device = self._ota_cluster_handler._endpoint.device # pylint: disable=protected-access + + if self._ota_cluster_handler.current_file_version is not None: + return f"0x{self._ota_cluster_handler.current_file_version:08x}" + + if device.sw_version is not None: + return device.sw_version + + return None @callback - def determine_installed_version(self) -> str: - """Determine the currently installed firmware version.""" - currently_installed_version = self._ota_cluster_handler.current_file_version - version_from_dr = self.zha_device.sw_version - if currently_installed_version == UNKNOWN and version_from_dr: - currently_installed_version = version_from_dr - return currently_installed_version - - @property - def extra_restore_state_data(self) -> ZHAFirmwareUpdateExtraStoredData: - """Return ZHA firmware update specific state data to be restored.""" - return ZHAFirmwareUpdateExtraStoredData(self._image_type) + def attribute_updated(self, attrid: int, name: str, value: Any) -> None: + """Handle attribute updates on the OTA cluster.""" + if attrid == Ota.AttributeDefs.current_file_version.id: + self._attr_installed_version = f"0x{value:08x}" + self.async_write_ha_state() @callback - def device_ota_update_available(self, image: BaseOTAImage) -> None: + def device_ota_update_available( + self, image: OtaImageWithMetadata, current_file_version: int + ) -> None: """Handle ota update available signal from Zigpy.""" - self._latest_version_firmware = image - self._attr_latest_version = f"0x{image.header.file_version:08x}" - self._image_type = image.header.image_type - self._attr_installed_version = self.determine_installed_version() + self._latest_firmware = image + self._attr_latest_version = f"0x{image.version:08x}" + self._attr_installed_version = f"0x{current_file_version:08x}" + + if image.metadata.changelog: + self._attr_release_summary = image.metadata.changelog + self.async_write_ha_state() @callback def _update_progress(self, current: int, total: int, progress: float) -> None: """Update install progress on event.""" - assert self._latest_version_firmware - self._attr_in_progress = int(progress) + # If we are not supposed to be updating, do nothing + if self._attr_in_progress is False: + return + + # Remap progress to 2-100 to avoid 0 and 1 + self._attr_in_progress = int(math.ceil(2 + 98 * progress / 100)) self.async_write_ha_state() - @callback - def _reset_progress(self, write_state: bool = True) -> None: - """Reset update install progress.""" - self._result = None - self._attr_in_progress = False - if write_state: - self.async_write_ha_state() - - async def async_update(self) -> None: - """Handle the update entity service call to manually check for available firmware updates.""" - await super().async_update() - # check for updates in the HA settings menu can invoke this so we need to check if the device - # is mains powered so we don't get a ton of errors in the logs from sleepy devices. - if self.zha_device.available and self.zha_device.is_mains_powered: - await self._ota_cluster_handler.async_check_for_update() - async def async_install( self, version: str | None, backup: bool, **kwargs: Any ) -> None: """Install an update.""" - firmware = self._latest_version_firmware - assert firmware - self._reset_progress(False) + assert self._latest_firmware is not None + + # Set the progress to an indeterminate state self._attr_in_progress = True self.async_write_ha_state() try: - self._result = await self.zha_device.device.update_firmware( - self._latest_version_firmware, - self._update_progress, + result = await self.zha_device.device.update_firmware( + image=self._latest_firmware, + progress_callback=self._update_progress, ) except Exception as ex: - self._reset_progress() - raise HomeAssistantError(ex) from ex + raise HomeAssistantError(f"Update was not successful: {ex}") from ex - assert self._result is not None + # If we tried to install firmware that is no longer compatible with the device, + # bail out + if result == Status.NO_IMAGE_AVAILABLE: + self._attr_latest_version = self._attr_installed_version + self.async_write_ha_state() - # If the update was not successful, we should throw an error to let the user know - if self._result != Status.SUCCESS: - # save result since reset_progress will clear it - results = self._result - self._reset_progress() - raise HomeAssistantError(f"Update was not successful - result: {results}") + # If the update finished but was not successful, we should also throw an error + if result != Status.SUCCESS: + raise HomeAssistantError(f"Update was not successful: {result}") - # If we get here, all files were installed successfully - self._attr_installed_version = ( - self._attr_latest_version - ) = f"0x{firmware.header.file_version:08x}" - self._latest_version_firmware = None - self._reset_progress() + # Clear the state + self._latest_firmware = None + self._attr_in_progress = False + self.async_write_ha_state() async def async_added_to_hass(self) -> None: """Call when entity is added.""" await super().async_added_to_hass() - last_state = await self.async_get_last_state() - # If we have a complete previous state, use that to set the installed version - if ( - last_state - and self._attr_installed_version == UNKNOWN - and (installed_version := last_state.attributes.get(ATTR_INSTALLED_VERSION)) - ): - self._attr_installed_version = installed_version - # If we have a complete previous state, use that to set the latest version - if ( - last_state - and (latest_version := last_state.attributes.get(ATTR_LATEST_VERSION)) - is not None - and latest_version != UNKNOWN - ): - self._attr_latest_version = latest_version - # If we have no state or latest version to restore, or the latest version is - # the same as the installed version, we can set the latest - # version to installed so that the entity starts as off. - elif ( - not last_state - or not latest_version - or latest_version == self._attr_installed_version - ): - self._attr_latest_version = self._attr_installed_version - - if self._attr_latest_version != self._attr_installed_version and ( - extra_data := await self.async_get_last_extra_data() - ): - self._image_type = extra_data.as_dict()["image_type"] - if self._image_type: - self._latest_version_firmware = ( - await self.zha_device.device.application.ota.get_ota_image( - self.zha_device.manufacturer_code, self._image_type - ) - ) - # if we can't locate an image but we have a latest version that differs - # we should set the latest version to the installed version to avoid - # confusion and errors - if not self._latest_version_firmware: - self._attr_latest_version = self._attr_installed_version + # OTA events are sent by the device self.zha_device.device.add_listener(self) + self.async_accept_signal( + self._ota_cluster_handler, SIGNAL_ATTR_UPDATED, self.attribute_updated + ) async def async_will_remove_from_hass(self) -> None: """Call when entity will be removed.""" await super().async_will_remove_from_hass() - self._reset_progress(False) + self._attr_in_progress = False + + async def async_update(self) -> None: + """Update the entity.""" + await CoordinatorEntity.async_update(self) + await super().async_update() diff --git a/requirements_all.txt b/requirements_all.txt index 39a53408be7..8b456bcef33 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2950,7 +2950,7 @@ zigpy-zigate==0.12.0 zigpy-znp==0.12.1 # homeassistant.components.zha -zigpy==0.62.3 +zigpy==0.63.2 # homeassistant.components.zoneminder zm-py==0.5.4 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index a73492d5581..6dde980cb46 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2270,7 +2270,7 @@ zigpy-zigate==0.12.0 zigpy-znp==0.12.1 # homeassistant.components.zha -zigpy==0.62.3 +zigpy==0.63.2 # homeassistant.components.zwave_js zwave-js-server-python==0.55.3 diff --git a/tests/components/zha/conftest.py b/tests/components/zha/conftest.py index f29bad8b3af..36d0cbcff97 100644 --- a/tests/components/zha/conftest.py +++ b/tests/components/zha/conftest.py @@ -156,12 +156,7 @@ async def zigpy_app_controller(): zigpy.config.CONF_NWK_BACKUP_ENABLED: False, zigpy.config.CONF_TOPO_SCAN_ENABLED: False, zigpy.config.CONF_OTA: { - zigpy.config.CONF_OTA_IKEA: False, - zigpy.config.CONF_OTA_INOVELLI: False, - zigpy.config.CONF_OTA_LEDVANCE: False, - zigpy.config.CONF_OTA_SALUS: False, - zigpy.config.CONF_OTA_SONOFF: False, - zigpy.config.CONF_OTA_THIRDREALITY: False, + zigpy.config.CONF_OTA_ENABLED: False, }, } ) diff --git a/tests/components/zha/test_discover.py b/tests/components/zha/test_discover.py index 12b0456f2e2..1491b46005b 100644 --- a/tests/components/zha/test_discover.py +++ b/tests/components/zha/test_discover.py @@ -24,7 +24,8 @@ from homeassistant.components.zha.core.helpers import get_zha_gateway import homeassistant.components.zha.core.registries as zha_regs from homeassistant.const import Platform from homeassistant.core import HomeAssistant -import homeassistant.helpers.entity_registry as er +from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.entity_platform import EntityPlatform from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .zha_devices_list import ( @@ -80,8 +81,6 @@ async def test_devices( zha_device_joined_restored, ) -> None: """Test device discovery.""" - entity_registry = er.async_get(hass_disable_services) - zigpy_device = zigpy_device_mock( endpoints=device[SIG_ENDPOINTS], ieee="00:11:22:33:44:55:66:77", @@ -96,14 +95,13 @@ async def test_devices( if cluster_identify: cluster_identify.request.reset_mock() - orig_new_entity = Endpoint.async_new_entity - _dispatch = mock.MagicMock(wraps=orig_new_entity) - try: - Endpoint.async_new_entity = lambda *a, **kw: _dispatch(*a, **kw) + with patch( + "homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry", + side_effect=EntityPlatform._async_schedule_add_entities_for_entry, + autospec=True, + ) as mock_add_entities: zha_dev = await zha_device_joined_restored(zigpy_device) await hass_disable_services.async_block_till_done() - finally: - Endpoint.async_new_entity = orig_new_entity if cluster_identify: # We only identify on join @@ -136,60 +134,38 @@ async def test_devices( for ch in endpoint.client_cluster_handlers.values() } assert event_cluster_handlers == set(device[DEV_SIG_EVT_CLUSTER_HANDLERS]) - # we need to probe the class create entity factory so we need to reset this to get accurate results - zha_regs.ZHA_ENTITIES.clean_up() - # build a dict of entity_class -> (platform, unique_id, cluster_handlers) tuple - ha_ent_info = {} - created_entity_count = 0 - for call in _dispatch.call_args_list: - _, platform, entity_cls, unique_id, cluster_handlers = call[0] - # the factory can return None. We filter these out to get an accurate created entity count - response = entity_cls.create_entity(unique_id, zha_dev, cluster_handlers) - if response and not contains_ignored_suffix(response.unique_id): - created_entity_count += 1 - unique_id_head = UNIQUE_ID_HD.match(unique_id).group( - 0 - ) # ieee + endpoint_id - ha_ent_info[(unique_id_head, entity_cls.__name__)] = ( - platform, - unique_id, - cluster_handlers, - ) - for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items(): - platform, unique_id = comp_id + # Keep track of unhandled entities: they should always be ones we explicitly ignore + created_entities = { + entity.entity_id: entity + for mock_call in mock_add_entities.mock_calls + for entity in mock_call.args[1] + } + unhandled_entities = set(created_entities.keys()) + entity_registry = er.async_get(hass_disable_services) + + for (platform, unique_id), ent_info in device[DEV_SIG_ENT_MAP].items(): no_tail_id = NO_TAIL_ID.sub("", ent_info[DEV_SIG_ENT_MAP_ID]) ha_entity_id = entity_registry.async_get_entity_id(platform, "zha", unique_id) assert ha_entity_id is not None assert ha_entity_id.startswith(no_tail_id) - test_ent_class = ent_info[DEV_SIG_ENT_MAP_CLASS] - test_unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) - assert (test_unique_id_head, test_ent_class) in ha_ent_info + entity = created_entities[ha_entity_id] + unhandled_entities.remove(ha_entity_id) - ha_comp, ha_unique_id, ha_cluster_handlers = ha_ent_info[ - (test_unique_id_head, test_ent_class) - ] - assert platform is ha_comp.value + assert entity.platform.domain == platform + assert type(entity).__name__ == ent_info[DEV_SIG_ENT_MAP_CLASS] # unique_id used for discover is the same for "multi entities" - assert unique_id.startswith(ha_unique_id) - assert {ch.name for ch in ha_cluster_handlers} == set( + assert unique_id == entity.unique_id + assert {ch.name for ch in entity.cluster_handlers.values()} == set( ent_info[DEV_SIG_CLUSTER_HANDLERS] ) - assert created_entity_count == len(device[DEV_SIG_ENT_MAP]) - - entity_ids = hass_disable_services.states.async_entity_ids() - await hass_disable_services.async_block_till_done() - - zha_entity_ids = { - ent - for ent in entity_ids - if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS - } - assert zha_entity_ids == { - e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values() - } + # All unhandled entities should be ones we explicitly ignore + for entity_id in unhandled_entities: + domain = entity_id.split(".")[0] + assert domain in zha_const.PLATFORMS + assert contains_ignored_suffix(entity_id) def _get_first_identify_cluster(zigpy_device): diff --git a/tests/components/zha/test_update.py b/tests/components/zha/test_update.py index 894b5af9aba..29be109c673 100644 --- a/tests/components/zha/test_update.py +++ b/tests/components/zha/test_update.py @@ -1,10 +1,11 @@ """Test ZHA firmware updates.""" -from unittest.mock import AsyncMock, MagicMock, call, patch +from unittest.mock import AsyncMock, call, patch import pytest from zigpy.exceptions import DeliveryError -from zigpy.ota import CachedImage +from zigpy.ota import OtaImageWithMetadata import zigpy.ota.image as firmware +from zigpy.ota.providers import BaseOtaImageMetadata import zigpy.profiles.zha as zha import zigpy.types as t import zigpy.zcl.clusters.general as general @@ -21,17 +22,14 @@ from homeassistant.components.update import ( DOMAIN as UPDATE_DOMAIN, SERVICE_INSTALL, ) -from homeassistant.components.update.const import ATTR_SKIPPED_VERSION from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON, Platform -from homeassistant.core import HomeAssistant, State +from homeassistant.core import HomeAssistant from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component from .common import async_enable_traffic, find_entity_id, update_attribute_cache from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE -from tests.common import mock_restore_cache_with_extra_data - @pytest.fixture(autouse=True) def update_platform_only(): @@ -80,26 +78,32 @@ async def setup_test_data( update_attribute_cache(cluster) # set up firmware image - fw_image = firmware.OTAImage() - fw_image.subelements = [firmware.SubElement(tag_id=0x0000, data=b"fw_image")] - fw_header = firmware.OTAImageHeader( - file_version=fw_version, - image_type=0x90, - manufacturer_id=zigpy_device.manufacturer_id, - upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE, - header_version=256, - header_length=56, - field_control=0, - stack_version=2, - header_string="This is a test header!", - image_size=56 + 2 + 4 + 8, + fw_image = OtaImageWithMetadata( + metadata=BaseOtaImageMetadata( + file_version=fw_version, + manufacturer_id=0x1234, + image_type=0x90, + changelog="This is a test firmware image!", + ), + firmware=firmware.OTAImage( + header=firmware.OTAImageHeader( + upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE, + file_version=fw_version, + image_type=0x90, + manufacturer_id=0x1234, + header_version=256, + header_length=56, + field_control=0, + stack_version=2, + header_string="This is a test header!", + image_size=56 + 2 + 4 + 8, + ), + subelements=[firmware.SubElement(tag_id=0x0000, data=b"fw_image")], + ), ) - fw_image.header = fw_header - fw_image.should_update = MagicMock(return_value=True) - cached_image = CachedImage(fw_image) cluster.endpoint.device.application.ota.get_ota_image = AsyncMock( - return_value=None if file_not_found else cached_image + return_value=None if file_not_found else fw_image ) zha_device = await zha_device_joined_restored(zigpy_device) @@ -108,18 +112,15 @@ async def setup_test_data( return zha_device, cluster, fw_image, installed_fw_version -@pytest.mark.parametrize("initial_version_unknown", (False, True)) async def test_firmware_update_notification_from_zigpy( hass: HomeAssistant, zha_device_joined_restored, zigpy_device, - initial_version_unknown, ) -> None: """Test ZHA update platform - firmware update notification.""" zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( zha_device_joined_restored, zigpy_device, - skip_attribute_plugs=initial_version_unknown, ) entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) @@ -132,12 +133,16 @@ async def test_firmware_update_notification_from_zigpy( # simulate an image available notification await cluster._handle_query_next_image( - fw_image.header.field_control, - zha_device.manufacturer_code, - fw_image.header.image_type, - installed_fw_version, - fw_image.header.header_version, - tsn=15, + foundation.ZCLHeader.cluster( + tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id + ), + general.QueryNextImageCommand( + fw_image.firmware.header.field_control, + zha_device.manufacturer_code, + fw_image.firmware.header.image_type, + installed_fw_version, + fw_image.firmware.header.header_version, + ), ) await hass.async_block_till_done() @@ -146,7 +151,9 @@ async def test_firmware_update_notification_from_zigpy( attrs = state.attributes assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" + assert ( + attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}" + ) async def test_firmware_update_notification_from_service_call( @@ -167,35 +174,46 @@ async def test_firmware_update_notification_from_service_call( async def _async_image_notify_side_effect(*args, **kwargs): await cluster._handle_query_next_image( - fw_image.header.field_control, - zha_device.manufacturer_code, - fw_image.header.image_type, - installed_fw_version, - fw_image.header.header_version, - tsn=15, + foundation.ZCLHeader.cluster( + tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id + ), + general.QueryNextImageCommand( + fw_image.firmware.header.field_control, + zha_device.manufacturer_code, + fw_image.firmware.header.image_type, + installed_fw_version, + fw_image.firmware.header.header_version, + ), ) await async_setup_component(hass, HA_DOMAIN, {}) - cluster.image_notify = AsyncMock(side_effect=_async_image_notify_side_effect) - await hass.services.async_call( - HA_DOMAIN, - SERVICE_UPDATE_ENTITY, - service_data={ATTR_ENTITY_ID: entity_id}, - blocking=True, - ) - assert cluster.image_notify.await_count == 1 - assert cluster.image_notify.call_args_list[0] == call( - payload_type=cluster.ImageNotifyCommand.PayloadType.QueryJitter, - query_jitter=100, - ) + with patch( + "zigpy.ota.OTA.broadcast_notify", side_effect=_async_image_notify_side_effect + ): + await hass.services.async_call( + HA_DOMAIN, + SERVICE_UPDATE_ENTITY, + service_data={ATTR_ENTITY_ID: entity_id}, + blocking=True, + ) - await hass.async_block_till_done() - state = hass.states.get(entity_id) - assert state.state == STATE_ON - attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" - assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" + assert cluster.endpoint.device.application.ota.broadcast_notify.await_count == 1 + assert cluster.endpoint.device.application.ota.broadcast_notify.call_args_list[ + 0 + ] == call( + jitter=100, + ) + + await hass.async_block_till_done() + state = hass.states.get(entity_id) + assert state.state == STATE_ON + attrs = state.attributes + assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" + assert not attrs[ATTR_IN_PROGRESS] + assert ( + attrs[ATTR_LATEST_VERSION] + == f"0x{fw_image.firmware.header.file_version:08x}" + ) def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs): @@ -226,6 +244,7 @@ def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs): return ota_packet +@patch("zigpy.device.AFTER_OTA_ATTR_READ_DELAY", 0.01) async def test_firmware_update_success( hass: HomeAssistant, zha_device_joined_restored, zigpy_device ) -> None: @@ -234,6 +253,8 @@ async def test_firmware_update_success( zha_device_joined_restored, zigpy_device ) + assert installed_fw_version < fw_image.firmware.header.file_version + entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) assert entity_id is not None @@ -244,12 +265,15 @@ async def test_firmware_update_success( # simulate an image available notification await cluster._handle_query_next_image( - fw_image.header.field_control, - zha_device.manufacturer_code, - fw_image.header.image_type, - installed_fw_version, - fw_image.header.header_version, - tsn=15, + foundation.ZCLHeader.cluster( + tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id + ), + general.QueryNextImageCommand( + field_control=fw_image.firmware.header.field_control, + manufacturer_code=zha_device.manufacturer_code, + image_type=fw_image.firmware.header.image_type, + current_file_version=installed_fw_version, + ), ) await hass.async_block_till_done() @@ -258,7 +282,9 @@ async def test_firmware_update_success( attrs = state.attributes assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" + assert ( + attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}" + ) async def endpoint_reply(cluster_id, tsn, data, command_id): if cluster_id == general.Ota.cluster_id: @@ -270,9 +296,9 @@ async def test_firmware_update_success( cluster, general.Ota.ServerCommandDefs.query_next_image.name, field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion, - manufacturer_code=fw_image.header.manufacturer_id, - image_type=fw_image.header.image_type, - current_file_version=fw_image.header.file_version - 10, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + current_file_version=fw_image.firmware.header.file_version - 10, hardware_version=1, ) ) @@ -280,19 +306,19 @@ async def test_firmware_update_success( cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema ): assert cmd.status == foundation.Status.SUCCESS - assert cmd.manufacturer_code == fw_image.header.manufacturer_id - assert cmd.image_type == fw_image.header.image_type - assert cmd.file_version == fw_image.header.file_version - assert cmd.image_size == fw_image.header.image_size + assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id + assert cmd.image_type == fw_image.firmware.header.image_type + assert cmd.file_version == fw_image.firmware.header.file_version + assert cmd.image_size == fw_image.firmware.header.image_size zigpy_device.packet_received( make_packet( zigpy_device, cluster, general.Ota.ServerCommandDefs.image_block.name, field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr, - manufacturer_code=fw_image.header.manufacturer_id, - image_type=fw_image.header.image_type, - file_version=fw_image.header.file_version, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + file_version=fw_image.firmware.header.file_version, file_offset=0, maximum_data_size=40, request_node_addr=zigpy_device.ieee, @@ -303,20 +329,23 @@ async def test_firmware_update_success( ): if cmd.file_offset == 0: assert cmd.status == foundation.Status.SUCCESS - assert cmd.manufacturer_code == fw_image.header.manufacturer_id - assert cmd.image_type == fw_image.header.image_type - assert cmd.file_version == fw_image.header.file_version + assert ( + cmd.manufacturer_code + == fw_image.firmware.header.manufacturer_id + ) + assert cmd.image_type == fw_image.firmware.header.image_type + assert cmd.file_version == fw_image.firmware.header.file_version assert cmd.file_offset == 0 - assert cmd.image_data == fw_image.serialize()[0:40] + assert cmd.image_data == fw_image.firmware.serialize()[0:40] zigpy_device.packet_received( make_packet( zigpy_device, cluster, general.Ota.ServerCommandDefs.image_block.name, field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr, - manufacturer_code=fw_image.header.manufacturer_id, - image_type=fw_image.header.image_type, - file_version=fw_image.header.file_version, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + file_version=fw_image.firmware.header.file_version, file_offset=40, maximum_data_size=40, request_node_addr=zigpy_device.ieee, @@ -324,11 +353,14 @@ async def test_firmware_update_success( ) elif cmd.file_offset == 40: assert cmd.status == foundation.Status.SUCCESS - assert cmd.manufacturer_code == fw_image.header.manufacturer_id - assert cmd.image_type == fw_image.header.image_type - assert cmd.file_version == fw_image.header.file_version + assert ( + cmd.manufacturer_code + == fw_image.firmware.header.manufacturer_id + ) + assert cmd.image_type == fw_image.firmware.header.image_type + assert cmd.file_version == fw_image.firmware.header.file_version assert cmd.file_offset == 40 - assert cmd.image_data == fw_image.serialize()[40:70] + assert cmd.image_data == fw_image.firmware.serialize()[40:70] # make sure the state machine gets progress reports state = hass.states.get(entity_id) @@ -337,10 +369,10 @@ async def test_firmware_update_success( assert ( attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" ) - assert attrs[ATTR_IN_PROGRESS] == 57 + assert attrs[ATTR_IN_PROGRESS] == 58 assert ( attrs[ATTR_LATEST_VERSION] - == f"0x{fw_image.header.file_version:08x}" + == f"0x{fw_image.firmware.header.file_version:08x}" ) zigpy_device.packet_received( @@ -349,21 +381,34 @@ async def test_firmware_update_success( cluster, general.Ota.ServerCommandDefs.upgrade_end.name, status=foundation.Status.SUCCESS, - manufacturer_code=fw_image.header.manufacturer_id, - image_type=fw_image.header.image_type, - file_version=fw_image.header.file_version, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + file_version=fw_image.firmware.header.file_version, ) ) elif isinstance( cmd, general.Ota.ClientCommandDefs.upgrade_end_response.schema ): - assert cmd.manufacturer_code == fw_image.header.manufacturer_id - assert cmd.image_type == fw_image.header.image_type - assert cmd.file_version == fw_image.header.file_version + assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id + assert cmd.image_type == fw_image.firmware.header.image_type + assert cmd.file_version == fw_image.firmware.header.file_version assert cmd.current_time == 0 assert cmd.upgrade_time == 0 + def read_new_fw_version(*args, **kwargs): + cluster.update_attribute( + attrid=general.Ota.AttributeDefs.current_file_version.id, + value=fw_image.firmware.header.file_version, + ) + return { + general.Ota.AttributeDefs.current_file_version.id: ( + fw_image.firmware.header.file_version + ) + }, {} + + cluster.read_attributes.side_effect = read_new_fw_version + cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply) await hass.services.async_call( UPDATE_DOMAIN, @@ -377,10 +422,21 @@ async def test_firmware_update_success( state = hass.states.get(entity_id) assert state.state == STATE_OFF attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == f"0x{fw_image.header.file_version:08x}" + assert ( + attrs[ATTR_INSTALLED_VERSION] + == f"0x{fw_image.firmware.header.file_version:08x}" + ) assert not attrs[ATTR_IN_PROGRESS] assert attrs[ATTR_LATEST_VERSION] == attrs[ATTR_INSTALLED_VERSION] + # If we send a progress notification incorrectly, it won't be handled + entity = hass.data[UPDATE_DOMAIN].get_entity(entity_id) + entity._update_progress(50, 100, 0.50) + + state = hass.states.get(entity_id) + assert not attrs[ATTR_IN_PROGRESS] + assert state.state == STATE_OFF + async def test_firmware_update_raises( hass: HomeAssistant, zha_device_joined_restored, zigpy_device @@ -400,12 +456,16 @@ async def test_firmware_update_raises( # simulate an image available notification await cluster._handle_query_next_image( - fw_image.header.field_control, - zha_device.manufacturer_code, - fw_image.header.image_type, - installed_fw_version, - fw_image.header.header_version, - tsn=15, + foundation.ZCLHeader.cluster( + tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id + ), + general.QueryNextImageCommand( + fw_image.firmware.header.field_control, + zha_device.manufacturer_code, + fw_image.firmware.header.image_type, + installed_fw_version, + fw_image.firmware.header.header_version, + ), ) await hass.async_block_till_done() @@ -414,7 +474,9 @@ async def test_firmware_update_raises( attrs = state.attributes assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" + assert ( + attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}" + ) async def endpoint_reply(cluster_id, tsn, data, command_id): if cluster_id == general.Ota.cluster_id: @@ -426,9 +488,9 @@ async def test_firmware_update_raises( cluster, general.Ota.ServerCommandDefs.query_next_image.name, field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion, - manufacturer_code=fw_image.header.manufacturer_id, - image_type=fw_image.header.image_type, - current_file_version=fw_image.header.file_version - 10, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + current_file_version=fw_image.firmware.header.file_version - 10, hardware_version=1, ) ) @@ -436,10 +498,10 @@ async def test_firmware_update_raises( cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema ): assert cmd.status == foundation.Status.SUCCESS - assert cmd.manufacturer_code == fw_image.header.manufacturer_id - assert cmd.image_type == fw_image.header.image_type - assert cmd.file_version == fw_image.header.file_version - assert cmd.image_size == fw_image.header.image_size + assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id + assert cmd.image_type == fw_image.firmware.header.image_type + assert cmd.file_version == fw_image.firmware.header.file_version + assert cmd.image_size == fw_image.firmware.header.image_size raise DeliveryError("failed to deliver") cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply) @@ -467,29 +529,10 @@ async def test_firmware_update_raises( ) -async def test_firmware_update_restore_data( +async def test_firmware_update_no_longer_compatible( hass: HomeAssistant, zha_device_joined_restored, zigpy_device ) -> None: - """Test ZHA update platform - restore data.""" - fw_version = 0x12345678 - installed_fw_version = fw_version - 10 - mock_restore_cache_with_extra_data( - hass, - [ - ( - State( - "update.fakemanufacturer_fakemodel_firmware", - STATE_ON, - { - ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}", - ATTR_LATEST_VERSION: f"0x{fw_version:08x}", - ATTR_SKIPPED_VERSION: None, - }, - ), - {"image_type": 0x90}, - ) - ], - ) + """Test ZHA update platform - firmware update is no longer valid.""" zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( zha_device_joined_restored, zigpy_device ) @@ -500,94 +543,67 @@ async def test_firmware_update_restore_data( # allow traffic to flow through the gateway and device await async_enable_traffic(hass, [zha_device]) + assert hass.states.get(entity_id).state == STATE_OFF + + # simulate an image available notification + await cluster._handle_query_next_image( + foundation.ZCLHeader.cluster( + tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id + ), + general.QueryNextImageCommand( + fw_image.firmware.header.field_control, + zha_device.manufacturer_code, + fw_image.firmware.header.image_type, + installed_fw_version, + fw_image.firmware.header.header_version, + ), + ) + + await hass.async_block_till_done() state = hass.states.get(entity_id) assert state.state == STATE_ON attrs = state.attributes assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" - - -async def test_firmware_update_restore_file_not_found( - hass: HomeAssistant, zha_device_joined_restored, zigpy_device -) -> None: - """Test ZHA update platform - restore data - file not found.""" - fw_version = 0x12345678 - installed_fw_version = fw_version - 10 - mock_restore_cache_with_extra_data( - hass, - [ - ( - State( - "update.fakemanufacturer_fakemodel_firmware", - STATE_ON, - { - ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}", - ATTR_LATEST_VERSION: f"0x{fw_version:08x}", - ATTR_SKIPPED_VERSION: None, - }, - ), - {"image_type": 0x90}, - ) - ], - ) - zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( - zha_device_joined_restored, zigpy_device, file_not_found=True + assert ( + attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}" ) - entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) - assert entity_id is not None + new_version = 0x99999999 - # allow traffic to flow through the gateway and device - await async_enable_traffic(hass, [zha_device]) + async def endpoint_reply(cluster_id, tsn, data, command_id): + if cluster_id == general.Ota.cluster_id: + hdr, cmd = cluster.deserialize(data) + if isinstance(cmd, general.Ota.ImageNotifyCommand): + zigpy_device.packet_received( + make_packet( + zigpy_device, + cluster, + general.Ota.ServerCommandDefs.query_next_image.name, + field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion, + manufacturer_code=fw_image.firmware.header.manufacturer_id, + image_type=fw_image.firmware.header.image_type, + # The device reports that it is no longer compatible! + current_file_version=new_version, + hardware_version=1, + ) + ) + cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply) + with pytest.raises(HomeAssistantError): + await hass.services.async_call( + UPDATE_DOMAIN, + SERVICE_INSTALL, + { + ATTR_ENTITY_ID: entity_id, + }, + blocking=True, + ) + + # We updated the currently installed firmware version, as it is no longer valid state = hass.states.get(entity_id) assert state.state == STATE_OFF attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" + assert attrs[ATTR_INSTALLED_VERSION] == f"0x{new_version:08x}" assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}" - - -async def test_firmware_update_restore_version_from_state_machine( - hass: HomeAssistant, zha_device_joined_restored, zigpy_device -) -> None: - """Test ZHA update platform - restore data - file not found.""" - fw_version = 0x12345678 - installed_fw_version = fw_version - 10 - mock_restore_cache_with_extra_data( - hass, - [ - ( - State( - "update.fakemanufacturer_fakemodel_firmware", - STATE_ON, - { - ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}", - ATTR_LATEST_VERSION: f"0x{fw_version:08x}", - ATTR_SKIPPED_VERSION: None, - }, - ), - {"image_type": 0x90}, - ) - ], - ) - zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( - zha_device_joined_restored, - zigpy_device, - skip_attribute_plugs=True, - file_not_found=True, - ) - - entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) - assert entity_id is not None - - # allow traffic to flow through the gateway and device - await async_enable_traffic(hass, [zha_device]) - - state = hass.states.get(entity_id) - assert state.state == STATE_OFF - attrs = state.attributes - assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" - assert not attrs[ATTR_IN_PROGRESS] - assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}" + assert attrs[ATTR_LATEST_VERSION] == f"0x{new_version:08x}"