Use new zigpy OTA providers for ZHA (#111159)

* Use `None` instead of `"unknown"` when the current version is unknown

* Only use the current file version from the OTA notification

* Use `sw_version`, if available, and update `current_file_version`

* Assume the current version is the latest version

* Fix lint errors

* Use `image` instead of `firmware`

* Include a changelog if updates expose it

* Clear latest firmware only after updating the installed version

* Bump minimum zigpy version to 0.63.0

* Create a data update coordinator to consolidate updates

* Fix overridden `async_update`

* Fix most unit tests

* Simplify `test_devices` to fix current tests

* Use a dict comprehension for creating mocked entities

* Fix unit tests (thanks @dmulcahey!)

* Update the currently installed version on cluster attribute update

* Drop `PARALLEL_UPDATES` now that we use an update coordinator

* Drop `_reset_progress`, it is already handled by the update component

* Do not update the progress if we are not supposed to be updating

* Ignore latest version (e.g. if device attrs changed) if zigpy rejects it

* Clean up handling of command id in `Ota.cluster_command`

* Start progress at 1%: 0 and False are considered equal and are filtered!

Use `ceil` instead of remapping 1-100

* The installed version will be auto-updated when the upgrade succeeds

* Avoid 1 as well, it collides with `True`

* Bump zigpy to (unreleased) 0.63.2

* Fix unit tests

* Fix existing unit tests

Send both event types

Globally enable sending both event types

* Remove unnecessary branches

* Test ignoring invalid progress callbacks

* Test updating a device with a no longer compatible firmware
This commit is contained in:
puddly 2024-02-28 14:38:04 -05:00 committed by GitHub
parent 4895f92551
commit 4ec75d6ca7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 388 additions and 423 deletions

View File

@ -627,8 +627,9 @@ class ClientClusterHandler(ClusterHandler):
"""ClusterHandler for Zigbee client (output) clusters."""
@callback
def attribute_updated(self, attrid: int, value: Any, _: Any) -> None:
def attribute_updated(self, attrid: int, value: Any, timestamp: Any) -> None:
"""Handle an attribute updated on this cluster."""
super().attribute_updated(attrid, value, timestamp)
try:
attr_name = self._cluster.attributes[attrid].name

View File

@ -56,7 +56,6 @@ from ..const import (
SIGNAL_MOVE_LEVEL,
SIGNAL_SET_LEVEL,
SIGNAL_UPDATE_DEVICE,
UNKNOWN as ZHA_UNKNOWN,
)
from . import (
AttrReportConfig,
@ -538,14 +537,9 @@ class OtaClusterHandler(ClusterHandler):
}
@property
def current_file_version(self) -> str:
def current_file_version(self) -> int | None:
"""Return cached value of current_file_version attribute."""
current_file_version = self.cluster.get(
Ota.AttributeDefs.current_file_version.name
)
if current_file_version is not None:
return f"0x{int(current_file_version):08x}"
return ZHA_UNKNOWN
return self.cluster.get(Ota.AttributeDefs.current_file_version.name)
@registries.CLIENT_CLUSTER_HANDLER_REGISTRY.register(Ota.cluster_id)
@ -559,36 +553,31 @@ class OtaClientClusterHandler(ClientClusterHandler):
}
@property
def current_file_version(self) -> str:
def current_file_version(self) -> int | None:
"""Return cached value of current_file_version attribute."""
current_file_version = self.cluster.get(
Ota.AttributeDefs.current_file_version.name
)
if current_file_version is not None:
return f"0x{int(current_file_version):08x}"
return ZHA_UNKNOWN
return self.cluster.get(Ota.AttributeDefs.current_file_version.name)
@callback
def cluster_command(
self, tsn: int, command_id: int, args: list[Any] | None
) -> None:
"""Handle OTA commands."""
if command_id in self.cluster.server_commands:
cmd_name = self.cluster.server_commands[command_id].name
else:
cmd_name = command_id
if command_id not in self.cluster.server_commands:
return
signal_id = self._endpoint.unique_id.split("-")[0]
cmd_name = self.cluster.server_commands[command_id].name
if cmd_name == Ota.ServerCommandDefs.query_next_image.name:
assert args
self.async_send_signal(SIGNAL_UPDATE_DEVICE.format(signal_id), args[3])
async def async_check_for_update(self):
"""Check for firmware availability by issuing an image notify command."""
await self.cluster.image_notify(
payload_type=(self.cluster.ImageNotifyCommand.PayloadType.QueryJitter),
query_jitter=100,
)
current_file_version = args[3]
self.cluster.update_attribute(
Ota.AttributeDefs.current_file_version.id, current_file_version
)
self.async_send_signal(
SIGNAL_UPDATE_DEVICE.format(signal_id), current_file_version
)
@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(Partition.cluster_id)

View File

@ -75,11 +75,12 @@ async def async_add_entities(
tuple[str, ZHADevice, list[ClusterHandler]],
]
],
**kwargs,
) -> None:
"""Add entities helper."""
if not entities:
return
to_add = [ent_cls.create_entity(*args) for ent_cls, args in entities]
to_add = [ent_cls.create_entity(*args, **kwargs) for ent_cls, args in entities]
entities_to_add = [entity for entity in to_add if entity is not None]
_async_add_entities(entities_to_add, update_before_add=False)
entities.clear()

View File

@ -27,7 +27,7 @@
"pyserial-asyncio==0.6",
"zha-quirks==0.0.112",
"zigpy-deconz==0.23.1",
"zigpy==0.62.3",
"zigpy==0.63.2",
"zigpy-xbee==0.20.1",
"zigpy-zigate==0.12.0",
"zigpy-znp==0.12.1",

View File

@ -1,17 +1,16 @@
"""Representation of ZHA updates."""
from __future__ import annotations
from dataclasses import dataclass
import functools
import logging
import math
from typing import TYPE_CHECKING, Any
from zigpy.ota.image import BaseOTAImage
from zigpy.types import uint16_t
from zigpy.ota import OtaImageWithMetadata
from zigpy.zcl.clusters.general import Ota
from zigpy.zcl.foundation import Status
from homeassistant.components.update import (
ATTR_INSTALLED_VERSION,
ATTR_LATEST_VERSION,
UpdateDeviceClass,
UpdateEntity,
UpdateEntityFeature,
@ -22,36 +21,29 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import ExtraStoredData
from homeassistant.helpers.update_coordinator import (
CoordinatorEntity,
DataUpdateCoordinator,
)
from .core import discovery
from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, UNKNOWN
from .core.helpers import get_zha_data
from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED
from .core.helpers import get_zha_data, get_zha_gateway
from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity
if TYPE_CHECKING:
from zigpy.application import ControllerApplication
from .core.cluster_handlers import ClusterHandler
from .core.device import ZHADevice
_LOGGER = logging.getLogger(__name__)
CONFIG_DIAGNOSTIC_MATCH = functools.partial(
ZHA_ENTITIES.config_diagnostic_match, Platform.UPDATE
)
# don't let homeassistant check for updates button hammer the zigbee network
PARALLEL_UPDATES = 1
@dataclass
class ZHAFirmwareUpdateExtraStoredData(ExtraStoredData):
"""Extra stored data for ZHA firmware update entity."""
image_type: uint16_t | None
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the extra data."""
return {"image_type": self.image_type}
async def async_setup_entry(
hass: HomeAssistant,
@ -62,18 +54,46 @@ async def async_setup_entry(
zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.UPDATE]
coordinator = ZHAFirmwareUpdateCoordinator(
hass, get_zha_gateway(hass).application_controller
)
unsub = async_dispatcher_connect(
hass,
SIGNAL_ADD_ENTITIES,
functools.partial(
discovery.async_add_entities, async_add_entities, entities_to_create
discovery.async_add_entities,
async_add_entities,
entities_to_create,
coordinator=coordinator,
),
)
config_entry.async_on_unload(unsub)
class ZHAFirmwareUpdateCoordinator(DataUpdateCoordinator): # pylint: disable=hass-enforce-coordinator-module
"""Firmware update coordinator that broadcasts updates network-wide."""
def __init__(
self, hass: HomeAssistant, controller_application: ControllerApplication
) -> None:
"""Initialize the coordinator."""
super().__init__(
hass,
_LOGGER,
name="ZHA firmware update coordinator",
update_method=self.async_update_data,
)
self.controller_application = controller_application
async def async_update_data(self) -> None:
"""Fetch the latest firmware update data."""
# Broadcast to all devices
await self.controller_application.ota.broadcast_notify(jitter=100)
@CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_OTA)
class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity):
class ZHAFirmwareUpdateEntity(ZhaEntity, CoordinatorEntity, UpdateEntity):
"""Representation of a ZHA firmware update entity."""
_unique_id_suffix = "firmware_update"
@ -90,147 +110,114 @@ class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity):
unique_id: str,
zha_device: ZHADevice,
channels: list[ClusterHandler],
coordinator: ZHAFirmwareUpdateCoordinator,
**kwargs: Any,
) -> None:
"""Initialize the ZHA update entity."""
super().__init__(unique_id, zha_device, channels, **kwargs)
CoordinatorEntity.__init__(self, coordinator)
self._ota_cluster_handler: ClusterHandler = self.cluster_handlers[
CLUSTER_HANDLER_OTA
]
self._attr_installed_version: str = self.determine_installed_version()
self._image_type: uint16_t | None = None
self._latest_version_firmware: BaseOTAImage | None = None
self._result = None
self._attr_installed_version: str | None = self._get_cluster_version()
self._attr_latest_version = self._attr_installed_version
self._latest_firmware: OtaImageWithMetadata | None = None
def _get_cluster_version(self) -> str | None:
"""Synchronize current file version with the cluster."""
device = self._ota_cluster_handler._endpoint.device # pylint: disable=protected-access
if self._ota_cluster_handler.current_file_version is not None:
return f"0x{self._ota_cluster_handler.current_file_version:08x}"
if device.sw_version is not None:
return device.sw_version
return None
@callback
def determine_installed_version(self) -> str:
"""Determine the currently installed firmware version."""
currently_installed_version = self._ota_cluster_handler.current_file_version
version_from_dr = self.zha_device.sw_version
if currently_installed_version == UNKNOWN and version_from_dr:
currently_installed_version = version_from_dr
return currently_installed_version
@property
def extra_restore_state_data(self) -> ZHAFirmwareUpdateExtraStoredData:
"""Return ZHA firmware update specific state data to be restored."""
return ZHAFirmwareUpdateExtraStoredData(self._image_type)
def attribute_updated(self, attrid: int, name: str, value: Any) -> None:
"""Handle attribute updates on the OTA cluster."""
if attrid == Ota.AttributeDefs.current_file_version.id:
self._attr_installed_version = f"0x{value:08x}"
self.async_write_ha_state()
@callback
def device_ota_update_available(self, image: BaseOTAImage) -> None:
def device_ota_update_available(
self, image: OtaImageWithMetadata, current_file_version: int
) -> None:
"""Handle ota update available signal from Zigpy."""
self._latest_version_firmware = image
self._attr_latest_version = f"0x{image.header.file_version:08x}"
self._image_type = image.header.image_type
self._attr_installed_version = self.determine_installed_version()
self._latest_firmware = image
self._attr_latest_version = f"0x{image.version:08x}"
self._attr_installed_version = f"0x{current_file_version:08x}"
if image.metadata.changelog:
self._attr_release_summary = image.metadata.changelog
self.async_write_ha_state()
@callback
def _update_progress(self, current: int, total: int, progress: float) -> None:
"""Update install progress on event."""
assert self._latest_version_firmware
self._attr_in_progress = int(progress)
# If we are not supposed to be updating, do nothing
if self._attr_in_progress is False:
return
# Remap progress to 2-100 to avoid 0 and 1
self._attr_in_progress = int(math.ceil(2 + 98 * progress / 100))
self.async_write_ha_state()
@callback
def _reset_progress(self, write_state: bool = True) -> None:
"""Reset update install progress."""
self._result = None
self._attr_in_progress = False
if write_state:
self.async_write_ha_state()
async def async_update(self) -> None:
"""Handle the update entity service call to manually check for available firmware updates."""
await super().async_update()
# check for updates in the HA settings menu can invoke this so we need to check if the device
# is mains powered so we don't get a ton of errors in the logs from sleepy devices.
if self.zha_device.available and self.zha_device.is_mains_powered:
await self._ota_cluster_handler.async_check_for_update()
async def async_install(
self, version: str | None, backup: bool, **kwargs: Any
) -> None:
"""Install an update."""
firmware = self._latest_version_firmware
assert firmware
self._reset_progress(False)
assert self._latest_firmware is not None
# Set the progress to an indeterminate state
self._attr_in_progress = True
self.async_write_ha_state()
try:
self._result = await self.zha_device.device.update_firmware(
self._latest_version_firmware,
self._update_progress,
result = await self.zha_device.device.update_firmware(
image=self._latest_firmware,
progress_callback=self._update_progress,
)
except Exception as ex:
self._reset_progress()
raise HomeAssistantError(ex) from ex
raise HomeAssistantError(f"Update was not successful: {ex}") from ex
assert self._result is not None
# If we tried to install firmware that is no longer compatible with the device,
# bail out
if result == Status.NO_IMAGE_AVAILABLE:
self._attr_latest_version = self._attr_installed_version
self.async_write_ha_state()
# If the update was not successful, we should throw an error to let the user know
if self._result != Status.SUCCESS:
# save result since reset_progress will clear it
results = self._result
self._reset_progress()
raise HomeAssistantError(f"Update was not successful - result: {results}")
# If the update finished but was not successful, we should also throw an error
if result != Status.SUCCESS:
raise HomeAssistantError(f"Update was not successful: {result}")
# If we get here, all files were installed successfully
self._attr_installed_version = (
self._attr_latest_version
) = f"0x{firmware.header.file_version:08x}"
self._latest_version_firmware = None
self._reset_progress()
# Clear the state
self._latest_firmware = None
self._attr_in_progress = False
self.async_write_ha_state()
async def async_added_to_hass(self) -> None:
"""Call when entity is added."""
await super().async_added_to_hass()
last_state = await self.async_get_last_state()
# If we have a complete previous state, use that to set the installed version
if (
last_state
and self._attr_installed_version == UNKNOWN
and (installed_version := last_state.attributes.get(ATTR_INSTALLED_VERSION))
):
self._attr_installed_version = installed_version
# If we have a complete previous state, use that to set the latest version
if (
last_state
and (latest_version := last_state.attributes.get(ATTR_LATEST_VERSION))
is not None
and latest_version != UNKNOWN
):
self._attr_latest_version = latest_version
# If we have no state or latest version to restore, or the latest version is
# the same as the installed version, we can set the latest
# version to installed so that the entity starts as off.
elif (
not last_state
or not latest_version
or latest_version == self._attr_installed_version
):
self._attr_latest_version = self._attr_installed_version
if self._attr_latest_version != self._attr_installed_version and (
extra_data := await self.async_get_last_extra_data()
):
self._image_type = extra_data.as_dict()["image_type"]
if self._image_type:
self._latest_version_firmware = (
await self.zha_device.device.application.ota.get_ota_image(
self.zha_device.manufacturer_code, self._image_type
)
)
# if we can't locate an image but we have a latest version that differs
# we should set the latest version to the installed version to avoid
# confusion and errors
if not self._latest_version_firmware:
self._attr_latest_version = self._attr_installed_version
# OTA events are sent by the device
self.zha_device.device.add_listener(self)
self.async_accept_signal(
self._ota_cluster_handler, SIGNAL_ATTR_UPDATED, self.attribute_updated
)
async def async_will_remove_from_hass(self) -> None:
"""Call when entity will be removed."""
await super().async_will_remove_from_hass()
self._reset_progress(False)
self._attr_in_progress = False
async def async_update(self) -> None:
"""Update the entity."""
await CoordinatorEntity.async_update(self)
await super().async_update()

View File

@ -2950,7 +2950,7 @@ zigpy-zigate==0.12.0
zigpy-znp==0.12.1
# homeassistant.components.zha
zigpy==0.62.3
zigpy==0.63.2
# homeassistant.components.zoneminder
zm-py==0.5.4

View File

@ -2270,7 +2270,7 @@ zigpy-zigate==0.12.0
zigpy-znp==0.12.1
# homeassistant.components.zha
zigpy==0.62.3
zigpy==0.63.2
# homeassistant.components.zwave_js
zwave-js-server-python==0.55.3

View File

@ -156,12 +156,7 @@ async def zigpy_app_controller():
zigpy.config.CONF_NWK_BACKUP_ENABLED: False,
zigpy.config.CONF_TOPO_SCAN_ENABLED: False,
zigpy.config.CONF_OTA: {
zigpy.config.CONF_OTA_IKEA: False,
zigpy.config.CONF_OTA_INOVELLI: False,
zigpy.config.CONF_OTA_LEDVANCE: False,
zigpy.config.CONF_OTA_SALUS: False,
zigpy.config.CONF_OTA_SONOFF: False,
zigpy.config.CONF_OTA_THIRDREALITY: False,
zigpy.config.CONF_OTA_ENABLED: False,
},
}
)

View File

@ -24,7 +24,8 @@ from homeassistant.components.zha.core.helpers import get_zha_gateway
import homeassistant.components.zha.core.registries as zha_regs
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import EntityPlatform
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from .zha_devices_list import (
@ -80,8 +81,6 @@ async def test_devices(
zha_device_joined_restored,
) -> None:
"""Test device discovery."""
entity_registry = er.async_get(hass_disable_services)
zigpy_device = zigpy_device_mock(
endpoints=device[SIG_ENDPOINTS],
ieee="00:11:22:33:44:55:66:77",
@ -96,14 +95,13 @@ async def test_devices(
if cluster_identify:
cluster_identify.request.reset_mock()
orig_new_entity = Endpoint.async_new_entity
_dispatch = mock.MagicMock(wraps=orig_new_entity)
try:
Endpoint.async_new_entity = lambda *a, **kw: _dispatch(*a, **kw)
with patch(
"homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry",
side_effect=EntityPlatform._async_schedule_add_entities_for_entry,
autospec=True,
) as mock_add_entities:
zha_dev = await zha_device_joined_restored(zigpy_device)
await hass_disable_services.async_block_till_done()
finally:
Endpoint.async_new_entity = orig_new_entity
if cluster_identify:
# We only identify on join
@ -136,60 +134,38 @@ async def test_devices(
for ch in endpoint.client_cluster_handlers.values()
}
assert event_cluster_handlers == set(device[DEV_SIG_EVT_CLUSTER_HANDLERS])
# we need to probe the class create entity factory so we need to reset this to get accurate results
zha_regs.ZHA_ENTITIES.clean_up()
# build a dict of entity_class -> (platform, unique_id, cluster_handlers) tuple
ha_ent_info = {}
created_entity_count = 0
for call in _dispatch.call_args_list:
_, platform, entity_cls, unique_id, cluster_handlers = call[0]
# the factory can return None. We filter these out to get an accurate created entity count
response = entity_cls.create_entity(unique_id, zha_dev, cluster_handlers)
if response and not contains_ignored_suffix(response.unique_id):
created_entity_count += 1
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
platform,
unique_id,
cluster_handlers,
)
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items():
platform, unique_id = comp_id
# Keep track of unhandled entities: they should always be ones we explicitly ignore
created_entities = {
entity.entity_id: entity
for mock_call in mock_add_entities.mock_calls
for entity in mock_call.args[1]
}
unhandled_entities = set(created_entities.keys())
entity_registry = er.async_get(hass_disable_services)
for (platform, unique_id), ent_info in device[DEV_SIG_ENT_MAP].items():
no_tail_id = NO_TAIL_ID.sub("", ent_info[DEV_SIG_ENT_MAP_ID])
ha_entity_id = entity_registry.async_get_entity_id(platform, "zha", unique_id)
assert ha_entity_id is not None
assert ha_entity_id.startswith(no_tail_id)
test_ent_class = ent_info[DEV_SIG_ENT_MAP_CLASS]
test_unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0)
assert (test_unique_id_head, test_ent_class) in ha_ent_info
entity = created_entities[ha_entity_id]
unhandled_entities.remove(ha_entity_id)
ha_comp, ha_unique_id, ha_cluster_handlers = ha_ent_info[
(test_unique_id_head, test_ent_class)
]
assert platform is ha_comp.value
assert entity.platform.domain == platform
assert type(entity).__name__ == ent_info[DEV_SIG_ENT_MAP_CLASS]
# unique_id used for discover is the same for "multi entities"
assert unique_id.startswith(ha_unique_id)
assert {ch.name for ch in ha_cluster_handlers} == set(
assert unique_id == entity.unique_id
assert {ch.name for ch in entity.cluster_handlers.values()} == set(
ent_info[DEV_SIG_CLUSTER_HANDLERS]
)
assert created_entity_count == len(device[DEV_SIG_ENT_MAP])
entity_ids = hass_disable_services.states.async_entity_ids()
await hass_disable_services.async_block_till_done()
zha_entity_ids = {
ent
for ent in entity_ids
if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS
}
assert zha_entity_ids == {
e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values()
}
# All unhandled entities should be ones we explicitly ignore
for entity_id in unhandled_entities:
domain = entity_id.split(".")[0]
assert domain in zha_const.PLATFORMS
assert contains_ignored_suffix(entity_id)
def _get_first_identify_cluster(zigpy_device):

View File

@ -1,10 +1,11 @@
"""Test ZHA firmware updates."""
from unittest.mock import AsyncMock, MagicMock, call, patch
from unittest.mock import AsyncMock, call, patch
import pytest
from zigpy.exceptions import DeliveryError
from zigpy.ota import CachedImage
from zigpy.ota import OtaImageWithMetadata
import zigpy.ota.image as firmware
from zigpy.ota.providers import BaseOtaImageMetadata
import zigpy.profiles.zha as zha
import zigpy.types as t
import zigpy.zcl.clusters.general as general
@ -21,17 +22,14 @@ from homeassistant.components.update import (
DOMAIN as UPDATE_DOMAIN,
SERVICE_INSTALL,
)
from homeassistant.components.update.const import ATTR_SKIPPED_VERSION
from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON, Platform
from homeassistant.core import HomeAssistant, State
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from .common import async_enable_traffic, find_entity_id, update_attribute_cache
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import mock_restore_cache_with_extra_data
@pytest.fixture(autouse=True)
def update_platform_only():
@ -80,26 +78,32 @@ async def setup_test_data(
update_attribute_cache(cluster)
# set up firmware image
fw_image = firmware.OTAImage()
fw_image.subelements = [firmware.SubElement(tag_id=0x0000, data=b"fw_image")]
fw_header = firmware.OTAImageHeader(
file_version=fw_version,
image_type=0x90,
manufacturer_id=zigpy_device.manufacturer_id,
upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE,
header_version=256,
header_length=56,
field_control=0,
stack_version=2,
header_string="This is a test header!",
image_size=56 + 2 + 4 + 8,
fw_image = OtaImageWithMetadata(
metadata=BaseOtaImageMetadata(
file_version=fw_version,
manufacturer_id=0x1234,
image_type=0x90,
changelog="This is a test firmware image!",
),
firmware=firmware.OTAImage(
header=firmware.OTAImageHeader(
upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE,
file_version=fw_version,
image_type=0x90,
manufacturer_id=0x1234,
header_version=256,
header_length=56,
field_control=0,
stack_version=2,
header_string="This is a test header!",
image_size=56 + 2 + 4 + 8,
),
subelements=[firmware.SubElement(tag_id=0x0000, data=b"fw_image")],
),
)
fw_image.header = fw_header
fw_image.should_update = MagicMock(return_value=True)
cached_image = CachedImage(fw_image)
cluster.endpoint.device.application.ota.get_ota_image = AsyncMock(
return_value=None if file_not_found else cached_image
return_value=None if file_not_found else fw_image
)
zha_device = await zha_device_joined_restored(zigpy_device)
@ -108,18 +112,15 @@ async def setup_test_data(
return zha_device, cluster, fw_image, installed_fw_version
@pytest.mark.parametrize("initial_version_unknown", (False, True))
async def test_firmware_update_notification_from_zigpy(
hass: HomeAssistant,
zha_device_joined_restored,
zigpy_device,
initial_version_unknown,
) -> None:
"""Test ZHA update platform - firmware update notification."""
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored,
zigpy_device,
skip_attribute_plugs=initial_version_unknown,
)
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
@ -132,12 +133,16 @@ async def test_firmware_update_notification_from_zigpy(
# simulate an image available notification
await cluster._handle_query_next_image(
fw_image.header.field_control,
zha_device.manufacturer_code,
fw_image.header.image_type,
installed_fw_version,
fw_image.header.header_version,
tsn=15,
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
fw_image.firmware.header.field_control,
zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
)
await hass.async_block_till_done()
@ -146,7 +151,9 @@ async def test_firmware_update_notification_from_zigpy(
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def test_firmware_update_notification_from_service_call(
@ -167,35 +174,46 @@ async def test_firmware_update_notification_from_service_call(
async def _async_image_notify_side_effect(*args, **kwargs):
await cluster._handle_query_next_image(
fw_image.header.field_control,
zha_device.manufacturer_code,
fw_image.header.image_type,
installed_fw_version,
fw_image.header.header_version,
tsn=15,
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
fw_image.firmware.header.field_control,
zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
)
await async_setup_component(hass, HA_DOMAIN, {})
cluster.image_notify = AsyncMock(side_effect=_async_image_notify_side_effect)
await hass.services.async_call(
HA_DOMAIN,
SERVICE_UPDATE_ENTITY,
service_data={ATTR_ENTITY_ID: entity_id},
blocking=True,
)
assert cluster.image_notify.await_count == 1
assert cluster.image_notify.call_args_list[0] == call(
payload_type=cluster.ImageNotifyCommand.PayloadType.QueryJitter,
query_jitter=100,
)
with patch(
"zigpy.ota.OTA.broadcast_notify", side_effect=_async_image_notify_side_effect
):
await hass.services.async_call(
HA_DOMAIN,
SERVICE_UPDATE_ENTITY,
service_data={ATTR_ENTITY_ID: entity_id},
blocking=True,
)
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state.state == STATE_ON
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
assert cluster.endpoint.device.application.ota.broadcast_notify.await_count == 1
assert cluster.endpoint.device.application.ota.broadcast_notify.call_args_list[
0
] == call(
jitter=100,
)
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state.state == STATE_ON
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert (
attrs[ATTR_LATEST_VERSION]
== f"0x{fw_image.firmware.header.file_version:08x}"
)
def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs):
@ -226,6 +244,7 @@ def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs):
return ota_packet
@patch("zigpy.device.AFTER_OTA_ATTR_READ_DELAY", 0.01)
async def test_firmware_update_success(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
@ -234,6 +253,8 @@ async def test_firmware_update_success(
zha_device_joined_restored, zigpy_device
)
assert installed_fw_version < fw_image.firmware.header.file_version
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
assert entity_id is not None
@ -244,12 +265,15 @@ async def test_firmware_update_success(
# simulate an image available notification
await cluster._handle_query_next_image(
fw_image.header.field_control,
zha_device.manufacturer_code,
fw_image.header.image_type,
installed_fw_version,
fw_image.header.header_version,
tsn=15,
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
field_control=fw_image.firmware.header.field_control,
manufacturer_code=zha_device.manufacturer_code,
image_type=fw_image.firmware.header.image_type,
current_file_version=installed_fw_version,
),
)
await hass.async_block_till_done()
@ -258,7 +282,9 @@ async def test_firmware_update_success(
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def endpoint_reply(cluster_id, tsn, data, command_id):
if cluster_id == general.Ota.cluster_id:
@ -270,9 +296,9 @@ async def test_firmware_update_success(
cluster,
general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.header.manufacturer_id,
image_type=fw_image.header.image_type,
current_file_version=fw_image.header.file_version - 10,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
current_file_version=fw_image.firmware.header.file_version - 10,
hardware_version=1,
)
)
@ -280,19 +306,19 @@ async def test_firmware_update_success(
cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema
):
assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type
assert cmd.file_version == fw_image.header.file_version
assert cmd.image_size == fw_image.header.image_size
assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.image_size == fw_image.firmware.header.image_size
zigpy_device.packet_received(
make_packet(
zigpy_device,
cluster,
general.Ota.ServerCommandDefs.image_block.name,
field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr,
manufacturer_code=fw_image.header.manufacturer_id,
image_type=fw_image.header.image_type,
file_version=fw_image.header.file_version,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
file_version=fw_image.firmware.header.file_version,
file_offset=0,
maximum_data_size=40,
request_node_addr=zigpy_device.ieee,
@ -303,20 +329,23 @@ async def test_firmware_update_success(
):
if cmd.file_offset == 0:
assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type
assert cmd.file_version == fw_image.header.file_version
assert (
cmd.manufacturer_code
== fw_image.firmware.header.manufacturer_id
)
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.file_offset == 0
assert cmd.image_data == fw_image.serialize()[0:40]
assert cmd.image_data == fw_image.firmware.serialize()[0:40]
zigpy_device.packet_received(
make_packet(
zigpy_device,
cluster,
general.Ota.ServerCommandDefs.image_block.name,
field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr,
manufacturer_code=fw_image.header.manufacturer_id,
image_type=fw_image.header.image_type,
file_version=fw_image.header.file_version,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
file_version=fw_image.firmware.header.file_version,
file_offset=40,
maximum_data_size=40,
request_node_addr=zigpy_device.ieee,
@ -324,11 +353,14 @@ async def test_firmware_update_success(
)
elif cmd.file_offset == 40:
assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type
assert cmd.file_version == fw_image.header.file_version
assert (
cmd.manufacturer_code
== fw_image.firmware.header.manufacturer_id
)
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.file_offset == 40
assert cmd.image_data == fw_image.serialize()[40:70]
assert cmd.image_data == fw_image.firmware.serialize()[40:70]
# make sure the state machine gets progress reports
state = hass.states.get(entity_id)
@ -337,10 +369,10 @@ async def test_firmware_update_success(
assert (
attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
)
assert attrs[ATTR_IN_PROGRESS] == 57
assert attrs[ATTR_IN_PROGRESS] == 58
assert (
attrs[ATTR_LATEST_VERSION]
== f"0x{fw_image.header.file_version:08x}"
== f"0x{fw_image.firmware.header.file_version:08x}"
)
zigpy_device.packet_received(
@ -349,21 +381,34 @@ async def test_firmware_update_success(
cluster,
general.Ota.ServerCommandDefs.upgrade_end.name,
status=foundation.Status.SUCCESS,
manufacturer_code=fw_image.header.manufacturer_id,
image_type=fw_image.header.image_type,
file_version=fw_image.header.file_version,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
file_version=fw_image.firmware.header.file_version,
)
)
elif isinstance(
cmd, general.Ota.ClientCommandDefs.upgrade_end_response.schema
):
assert cmd.manufacturer_code == fw_image.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type
assert cmd.file_version == fw_image.header.file_version
assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.current_time == 0
assert cmd.upgrade_time == 0
def read_new_fw_version(*args, **kwargs):
cluster.update_attribute(
attrid=general.Ota.AttributeDefs.current_file_version.id,
value=fw_image.firmware.header.file_version,
)
return {
general.Ota.AttributeDefs.current_file_version.id: (
fw_image.firmware.header.file_version
)
}, {}
cluster.read_attributes.side_effect = read_new_fw_version
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
await hass.services.async_call(
UPDATE_DOMAIN,
@ -377,10 +422,21 @@ async def test_firmware_update_success(
state = hass.states.get(entity_id)
assert state.state == STATE_OFF
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{fw_image.header.file_version:08x}"
assert (
attrs[ATTR_INSTALLED_VERSION]
== f"0x{fw_image.firmware.header.file_version:08x}"
)
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == attrs[ATTR_INSTALLED_VERSION]
# If we send a progress notification incorrectly, it won't be handled
entity = hass.data[UPDATE_DOMAIN].get_entity(entity_id)
entity._update_progress(50, 100, 0.50)
state = hass.states.get(entity_id)
assert not attrs[ATTR_IN_PROGRESS]
assert state.state == STATE_OFF
async def test_firmware_update_raises(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
@ -400,12 +456,16 @@ async def test_firmware_update_raises(
# simulate an image available notification
await cluster._handle_query_next_image(
fw_image.header.field_control,
zha_device.manufacturer_code,
fw_image.header.image_type,
installed_fw_version,
fw_image.header.header_version,
tsn=15,
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
fw_image.firmware.header.field_control,
zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
)
await hass.async_block_till_done()
@ -414,7 +474,9 @@ async def test_firmware_update_raises(
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def endpoint_reply(cluster_id, tsn, data, command_id):
if cluster_id == general.Ota.cluster_id:
@ -426,9 +488,9 @@ async def test_firmware_update_raises(
cluster,
general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.header.manufacturer_id,
image_type=fw_image.header.image_type,
current_file_version=fw_image.header.file_version - 10,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
current_file_version=fw_image.firmware.header.file_version - 10,
hardware_version=1,
)
)
@ -436,10 +498,10 @@ async def test_firmware_update_raises(
cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema
):
assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type
assert cmd.file_version == fw_image.header.file_version
assert cmd.image_size == fw_image.header.image_size
assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.image_size == fw_image.firmware.header.image_size
raise DeliveryError("failed to deliver")
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
@ -467,29 +529,10 @@ async def test_firmware_update_raises(
)
async def test_firmware_update_restore_data(
async def test_firmware_update_no_longer_compatible(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
"""Test ZHA update platform - restore data."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
"""Test ZHA update platform - firmware update is no longer valid."""
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored, zigpy_device
)
@ -500,94 +543,67 @@ async def test_firmware_update_restore_data(
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, [zha_device])
assert hass.states.get(entity_id).state == STATE_OFF
# simulate an image available notification
await cluster._handle_query_next_image(
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
fw_image.firmware.header.field_control,
zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
)
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state.state == STATE_ON
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
async def test_firmware_update_restore_file_not_found(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
"""Test ZHA update platform - restore data - file not found."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored, zigpy_device, file_not_found=True
assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
assert entity_id is not None
new_version = 0x99999999
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, [zha_device])
async def endpoint_reply(cluster_id, tsn, data, command_id):
if cluster_id == general.Ota.cluster_id:
hdr, cmd = cluster.deserialize(data)
if isinstance(cmd, general.Ota.ImageNotifyCommand):
zigpy_device.packet_received(
make_packet(
zigpy_device,
cluster,
general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
# The device reports that it is no longer compatible!
current_file_version=new_version,
hardware_version=1,
)
)
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
UPDATE_DOMAIN,
SERVICE_INSTALL,
{
ATTR_ENTITY_ID: entity_id,
},
blocking=True,
)
# We updated the currently installed firmware version, as it is no longer valid
state = hass.states.get(entity_id)
assert state.state == STATE_OFF
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{new_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}"
async def test_firmware_update_restore_version_from_state_machine(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
"""Test ZHA update platform - restore data - file not found."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored,
zigpy_device,
skip_attribute_plugs=True,
file_not_found=True,
)
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
assert entity_id is not None
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, [zha_device])
state = hass.states.get(entity_id)
assert state.state == STATE_OFF
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}"
assert attrs[ATTR_LATEST_VERSION] == f"0x{new_version:08x}"