Use new zigpy OTA providers for ZHA (#111159)

* Use `None` instead of `"unknown"` when the current version is unknown

* Only use the current file version from the OTA notification

* Use `sw_version`, if available, and update `current_file_version`

* Assume the current version is the latest version

* Fix lint errors

* Use `image` instead of `firmware`

* Include a changelog if updates expose it

* Clear latest firmware only after updating the installed version

* Bump minimum zigpy version to 0.63.0

* Create a data update coordinator to consolidate updates

* Fix overridden `async_update`

* Fix most unit tests

* Simplify `test_devices` to fix current tests

* Use a dict comprehension for creating mocked entities

* Fix unit tests (thanks @dmulcahey!)

* Update the currently installed version on cluster attribute update

* Drop `PARALLEL_UPDATES` now that we use an update coordinator

* Drop `_reset_progress`, it is already handled by the update component

* Do not update the progress if we are not supposed to be updating

* Ignore latest version (e.g. if device attrs changed) if zigpy rejects it

* Clean up handling of command id in `Ota.cluster_command`

* Start progress at 1%: 0 and False are considered equal and are filtered!

Use `ceil` instead of remapping 1-100

* The installed version will be auto-updated when the upgrade succeeds

* Avoid 1 as well, it collides with `True`

* Bump zigpy to (unreleased) 0.63.2

* Fix unit tests

* Fix existing unit tests

Send both event types

Globally enable sending both event types

* Remove unnecessary branches

* Test ignoring invalid progress callbacks

* Test updating a device with a no longer compatible firmware
This commit is contained in:
puddly 2024-02-28 14:38:04 -05:00 committed by GitHub
parent 4895f92551
commit 4ec75d6ca7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 388 additions and 423 deletions

View File

@ -627,8 +627,9 @@ class ClientClusterHandler(ClusterHandler):
"""ClusterHandler for Zigbee client (output) clusters.""" """ClusterHandler for Zigbee client (output) clusters."""
@callback @callback
def attribute_updated(self, attrid: int, value: Any, _: Any) -> None: def attribute_updated(self, attrid: int, value: Any, timestamp: Any) -> None:
"""Handle an attribute updated on this cluster.""" """Handle an attribute updated on this cluster."""
super().attribute_updated(attrid, value, timestamp)
try: try:
attr_name = self._cluster.attributes[attrid].name attr_name = self._cluster.attributes[attrid].name

View File

@ -56,7 +56,6 @@ from ..const import (
SIGNAL_MOVE_LEVEL, SIGNAL_MOVE_LEVEL,
SIGNAL_SET_LEVEL, SIGNAL_SET_LEVEL,
SIGNAL_UPDATE_DEVICE, SIGNAL_UPDATE_DEVICE,
UNKNOWN as ZHA_UNKNOWN,
) )
from . import ( from . import (
AttrReportConfig, AttrReportConfig,
@ -538,14 +537,9 @@ class OtaClusterHandler(ClusterHandler):
} }
@property @property
def current_file_version(self) -> str: def current_file_version(self) -> int | None:
"""Return cached value of current_file_version attribute.""" """Return cached value of current_file_version attribute."""
current_file_version = self.cluster.get( return self.cluster.get(Ota.AttributeDefs.current_file_version.name)
Ota.AttributeDefs.current_file_version.name
)
if current_file_version is not None:
return f"0x{int(current_file_version):08x}"
return ZHA_UNKNOWN
@registries.CLIENT_CLUSTER_HANDLER_REGISTRY.register(Ota.cluster_id) @registries.CLIENT_CLUSTER_HANDLER_REGISTRY.register(Ota.cluster_id)
@ -559,36 +553,31 @@ class OtaClientClusterHandler(ClientClusterHandler):
} }
@property @property
def current_file_version(self) -> str: def current_file_version(self) -> int | None:
"""Return cached value of current_file_version attribute.""" """Return cached value of current_file_version attribute."""
current_file_version = self.cluster.get( return self.cluster.get(Ota.AttributeDefs.current_file_version.name)
Ota.AttributeDefs.current_file_version.name
)
if current_file_version is not None:
return f"0x{int(current_file_version):08x}"
return ZHA_UNKNOWN
@callback @callback
def cluster_command( def cluster_command(
self, tsn: int, command_id: int, args: list[Any] | None self, tsn: int, command_id: int, args: list[Any] | None
) -> None: ) -> None:
"""Handle OTA commands.""" """Handle OTA commands."""
if command_id in self.cluster.server_commands: if command_id not in self.cluster.server_commands:
cmd_name = self.cluster.server_commands[command_id].name return
else:
cmd_name = command_id
signal_id = self._endpoint.unique_id.split("-")[0] signal_id = self._endpoint.unique_id.split("-")[0]
cmd_name = self.cluster.server_commands[command_id].name
if cmd_name == Ota.ServerCommandDefs.query_next_image.name: if cmd_name == Ota.ServerCommandDefs.query_next_image.name:
assert args assert args
self.async_send_signal(SIGNAL_UPDATE_DEVICE.format(signal_id), args[3])
async def async_check_for_update(self): current_file_version = args[3]
"""Check for firmware availability by issuing an image notify command.""" self.cluster.update_attribute(
await self.cluster.image_notify( Ota.AttributeDefs.current_file_version.id, current_file_version
payload_type=(self.cluster.ImageNotifyCommand.PayloadType.QueryJitter), )
query_jitter=100, self.async_send_signal(
) SIGNAL_UPDATE_DEVICE.format(signal_id), current_file_version
)
@registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(Partition.cluster_id) @registries.ZIGBEE_CLUSTER_HANDLER_REGISTRY.register(Partition.cluster_id)

View File

@ -75,11 +75,12 @@ async def async_add_entities(
tuple[str, ZHADevice, list[ClusterHandler]], tuple[str, ZHADevice, list[ClusterHandler]],
] ]
], ],
**kwargs,
) -> None: ) -> None:
"""Add entities helper.""" """Add entities helper."""
if not entities: if not entities:
return return
to_add = [ent_cls.create_entity(*args) for ent_cls, args in entities] to_add = [ent_cls.create_entity(*args, **kwargs) for ent_cls, args in entities]
entities_to_add = [entity for entity in to_add if entity is not None] entities_to_add = [entity for entity in to_add if entity is not None]
_async_add_entities(entities_to_add, update_before_add=False) _async_add_entities(entities_to_add, update_before_add=False)
entities.clear() entities.clear()

View File

@ -27,7 +27,7 @@
"pyserial-asyncio==0.6", "pyserial-asyncio==0.6",
"zha-quirks==0.0.112", "zha-quirks==0.0.112",
"zigpy-deconz==0.23.1", "zigpy-deconz==0.23.1",
"zigpy==0.62.3", "zigpy==0.63.2",
"zigpy-xbee==0.20.1", "zigpy-xbee==0.20.1",
"zigpy-zigate==0.12.0", "zigpy-zigate==0.12.0",
"zigpy-znp==0.12.1", "zigpy-znp==0.12.1",

View File

@ -1,17 +1,16 @@
"""Representation of ZHA updates.""" """Representation of ZHA updates."""
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass
import functools import functools
import logging
import math
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
from zigpy.ota.image import BaseOTAImage from zigpy.ota import OtaImageWithMetadata
from zigpy.types import uint16_t from zigpy.zcl.clusters.general import Ota
from zigpy.zcl.foundation import Status from zigpy.zcl.foundation import Status
from homeassistant.components.update import ( from homeassistant.components.update import (
ATTR_INSTALLED_VERSION,
ATTR_LATEST_VERSION,
UpdateDeviceClass, UpdateDeviceClass,
UpdateEntity, UpdateEntity,
UpdateEntityFeature, UpdateEntityFeature,
@ -22,36 +21,29 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.dispatcher import async_dispatcher_connect from homeassistant.helpers.dispatcher import async_dispatcher_connect
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import ExtraStoredData from homeassistant.helpers.update_coordinator import (
CoordinatorEntity,
DataUpdateCoordinator,
)
from .core import discovery from .core import discovery
from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, UNKNOWN from .core.const import CLUSTER_HANDLER_OTA, SIGNAL_ADD_ENTITIES, SIGNAL_ATTR_UPDATED
from .core.helpers import get_zha_data from .core.helpers import get_zha_data, get_zha_gateway
from .core.registries import ZHA_ENTITIES from .core.registries import ZHA_ENTITIES
from .entity import ZhaEntity from .entity import ZhaEntity
if TYPE_CHECKING: if TYPE_CHECKING:
from zigpy.application import ControllerApplication
from .core.cluster_handlers import ClusterHandler from .core.cluster_handlers import ClusterHandler
from .core.device import ZHADevice from .core.device import ZHADevice
_LOGGER = logging.getLogger(__name__)
CONFIG_DIAGNOSTIC_MATCH = functools.partial( CONFIG_DIAGNOSTIC_MATCH = functools.partial(
ZHA_ENTITIES.config_diagnostic_match, Platform.UPDATE ZHA_ENTITIES.config_diagnostic_match, Platform.UPDATE
) )
# don't let homeassistant check for updates button hammer the zigbee network
PARALLEL_UPDATES = 1
@dataclass
class ZHAFirmwareUpdateExtraStoredData(ExtraStoredData):
"""Extra stored data for ZHA firmware update entity."""
image_type: uint16_t | None
def as_dict(self) -> dict[str, Any]:
"""Return a dict representation of the extra data."""
return {"image_type": self.image_type}
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant,
@ -62,18 +54,46 @@ async def async_setup_entry(
zha_data = get_zha_data(hass) zha_data = get_zha_data(hass)
entities_to_create = zha_data.platforms[Platform.UPDATE] entities_to_create = zha_data.platforms[Platform.UPDATE]
coordinator = ZHAFirmwareUpdateCoordinator(
hass, get_zha_gateway(hass).application_controller
)
unsub = async_dispatcher_connect( unsub = async_dispatcher_connect(
hass, hass,
SIGNAL_ADD_ENTITIES, SIGNAL_ADD_ENTITIES,
functools.partial( functools.partial(
discovery.async_add_entities, async_add_entities, entities_to_create discovery.async_add_entities,
async_add_entities,
entities_to_create,
coordinator=coordinator,
), ),
) )
config_entry.async_on_unload(unsub) config_entry.async_on_unload(unsub)
class ZHAFirmwareUpdateCoordinator(DataUpdateCoordinator): # pylint: disable=hass-enforce-coordinator-module
"""Firmware update coordinator that broadcasts updates network-wide."""
def __init__(
self, hass: HomeAssistant, controller_application: ControllerApplication
) -> None:
"""Initialize the coordinator."""
super().__init__(
hass,
_LOGGER,
name="ZHA firmware update coordinator",
update_method=self.async_update_data,
)
self.controller_application = controller_application
async def async_update_data(self) -> None:
"""Fetch the latest firmware update data."""
# Broadcast to all devices
await self.controller_application.ota.broadcast_notify(jitter=100)
@CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_OTA) @CONFIG_DIAGNOSTIC_MATCH(cluster_handler_names=CLUSTER_HANDLER_OTA)
class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity): class ZHAFirmwareUpdateEntity(ZhaEntity, CoordinatorEntity, UpdateEntity):
"""Representation of a ZHA firmware update entity.""" """Representation of a ZHA firmware update entity."""
_unique_id_suffix = "firmware_update" _unique_id_suffix = "firmware_update"
@ -90,147 +110,114 @@ class ZHAFirmwareUpdateEntity(ZhaEntity, UpdateEntity):
unique_id: str, unique_id: str,
zha_device: ZHADevice, zha_device: ZHADevice,
channels: list[ClusterHandler], channels: list[ClusterHandler],
coordinator: ZHAFirmwareUpdateCoordinator,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the ZHA update entity.""" """Initialize the ZHA update entity."""
super().__init__(unique_id, zha_device, channels, **kwargs) super().__init__(unique_id, zha_device, channels, **kwargs)
CoordinatorEntity.__init__(self, coordinator)
self._ota_cluster_handler: ClusterHandler = self.cluster_handlers[ self._ota_cluster_handler: ClusterHandler = self.cluster_handlers[
CLUSTER_HANDLER_OTA CLUSTER_HANDLER_OTA
] ]
self._attr_installed_version: str = self.determine_installed_version() self._attr_installed_version: str | None = self._get_cluster_version()
self._image_type: uint16_t | None = None self._attr_latest_version = self._attr_installed_version
self._latest_version_firmware: BaseOTAImage | None = None self._latest_firmware: OtaImageWithMetadata | None = None
self._result = None
def _get_cluster_version(self) -> str | None:
"""Synchronize current file version with the cluster."""
device = self._ota_cluster_handler._endpoint.device # pylint: disable=protected-access
if self._ota_cluster_handler.current_file_version is not None:
return f"0x{self._ota_cluster_handler.current_file_version:08x}"
if device.sw_version is not None:
return device.sw_version
return None
@callback @callback
def determine_installed_version(self) -> str: def attribute_updated(self, attrid: int, name: str, value: Any) -> None:
"""Determine the currently installed firmware version.""" """Handle attribute updates on the OTA cluster."""
currently_installed_version = self._ota_cluster_handler.current_file_version if attrid == Ota.AttributeDefs.current_file_version.id:
version_from_dr = self.zha_device.sw_version self._attr_installed_version = f"0x{value:08x}"
if currently_installed_version == UNKNOWN and version_from_dr: self.async_write_ha_state()
currently_installed_version = version_from_dr
return currently_installed_version
@property
def extra_restore_state_data(self) -> ZHAFirmwareUpdateExtraStoredData:
"""Return ZHA firmware update specific state data to be restored."""
return ZHAFirmwareUpdateExtraStoredData(self._image_type)
@callback @callback
def device_ota_update_available(self, image: BaseOTAImage) -> None: def device_ota_update_available(
self, image: OtaImageWithMetadata, current_file_version: int
) -> None:
"""Handle ota update available signal from Zigpy.""" """Handle ota update available signal from Zigpy."""
self._latest_version_firmware = image self._latest_firmware = image
self._attr_latest_version = f"0x{image.header.file_version:08x}" self._attr_latest_version = f"0x{image.version:08x}"
self._image_type = image.header.image_type self._attr_installed_version = f"0x{current_file_version:08x}"
self._attr_installed_version = self.determine_installed_version()
if image.metadata.changelog:
self._attr_release_summary = image.metadata.changelog
self.async_write_ha_state() self.async_write_ha_state()
@callback @callback
def _update_progress(self, current: int, total: int, progress: float) -> None: def _update_progress(self, current: int, total: int, progress: float) -> None:
"""Update install progress on event.""" """Update install progress on event."""
assert self._latest_version_firmware # If we are not supposed to be updating, do nothing
self._attr_in_progress = int(progress) if self._attr_in_progress is False:
return
# Remap progress to 2-100 to avoid 0 and 1
self._attr_in_progress = int(math.ceil(2 + 98 * progress / 100))
self.async_write_ha_state() self.async_write_ha_state()
@callback
def _reset_progress(self, write_state: bool = True) -> None:
"""Reset update install progress."""
self._result = None
self._attr_in_progress = False
if write_state:
self.async_write_ha_state()
async def async_update(self) -> None:
"""Handle the update entity service call to manually check for available firmware updates."""
await super().async_update()
# check for updates in the HA settings menu can invoke this so we need to check if the device
# is mains powered so we don't get a ton of errors in the logs from sleepy devices.
if self.zha_device.available and self.zha_device.is_mains_powered:
await self._ota_cluster_handler.async_check_for_update()
async def async_install( async def async_install(
self, version: str | None, backup: bool, **kwargs: Any self, version: str | None, backup: bool, **kwargs: Any
) -> None: ) -> None:
"""Install an update.""" """Install an update."""
firmware = self._latest_version_firmware assert self._latest_firmware is not None
assert firmware
self._reset_progress(False) # Set the progress to an indeterminate state
self._attr_in_progress = True self._attr_in_progress = True
self.async_write_ha_state() self.async_write_ha_state()
try: try:
self._result = await self.zha_device.device.update_firmware( result = await self.zha_device.device.update_firmware(
self._latest_version_firmware, image=self._latest_firmware,
self._update_progress, progress_callback=self._update_progress,
) )
except Exception as ex: except Exception as ex:
self._reset_progress() raise HomeAssistantError(f"Update was not successful: {ex}") from ex
raise HomeAssistantError(ex) from ex
assert self._result is not None # If we tried to install firmware that is no longer compatible with the device,
# bail out
if result == Status.NO_IMAGE_AVAILABLE:
self._attr_latest_version = self._attr_installed_version
self.async_write_ha_state()
# If the update was not successful, we should throw an error to let the user know # If the update finished but was not successful, we should also throw an error
if self._result != Status.SUCCESS: if result != Status.SUCCESS:
# save result since reset_progress will clear it raise HomeAssistantError(f"Update was not successful: {result}")
results = self._result
self._reset_progress()
raise HomeAssistantError(f"Update was not successful - result: {results}")
# If we get here, all files were installed successfully # Clear the state
self._attr_installed_version = ( self._latest_firmware = None
self._attr_latest_version self._attr_in_progress = False
) = f"0x{firmware.header.file_version:08x}" self.async_write_ha_state()
self._latest_version_firmware = None
self._reset_progress()
async def async_added_to_hass(self) -> None: async def async_added_to_hass(self) -> None:
"""Call when entity is added.""" """Call when entity is added."""
await super().async_added_to_hass() await super().async_added_to_hass()
last_state = await self.async_get_last_state()
# If we have a complete previous state, use that to set the installed version
if (
last_state
and self._attr_installed_version == UNKNOWN
and (installed_version := last_state.attributes.get(ATTR_INSTALLED_VERSION))
):
self._attr_installed_version = installed_version
# If we have a complete previous state, use that to set the latest version
if (
last_state
and (latest_version := last_state.attributes.get(ATTR_LATEST_VERSION))
is not None
and latest_version != UNKNOWN
):
self._attr_latest_version = latest_version
# If we have no state or latest version to restore, or the latest version is
# the same as the installed version, we can set the latest
# version to installed so that the entity starts as off.
elif (
not last_state
or not latest_version
or latest_version == self._attr_installed_version
):
self._attr_latest_version = self._attr_installed_version
if self._attr_latest_version != self._attr_installed_version and (
extra_data := await self.async_get_last_extra_data()
):
self._image_type = extra_data.as_dict()["image_type"]
if self._image_type:
self._latest_version_firmware = (
await self.zha_device.device.application.ota.get_ota_image(
self.zha_device.manufacturer_code, self._image_type
)
)
# if we can't locate an image but we have a latest version that differs
# we should set the latest version to the installed version to avoid
# confusion and errors
if not self._latest_version_firmware:
self._attr_latest_version = self._attr_installed_version
# OTA events are sent by the device
self.zha_device.device.add_listener(self) self.zha_device.device.add_listener(self)
self.async_accept_signal(
self._ota_cluster_handler, SIGNAL_ATTR_UPDATED, self.attribute_updated
)
async def async_will_remove_from_hass(self) -> None: async def async_will_remove_from_hass(self) -> None:
"""Call when entity will be removed.""" """Call when entity will be removed."""
await super().async_will_remove_from_hass() await super().async_will_remove_from_hass()
self._reset_progress(False) self._attr_in_progress = False
async def async_update(self) -> None:
"""Update the entity."""
await CoordinatorEntity.async_update(self)
await super().async_update()

View File

@ -2950,7 +2950,7 @@ zigpy-zigate==0.12.0
zigpy-znp==0.12.1 zigpy-znp==0.12.1
# homeassistant.components.zha # homeassistant.components.zha
zigpy==0.62.3 zigpy==0.63.2
# homeassistant.components.zoneminder # homeassistant.components.zoneminder
zm-py==0.5.4 zm-py==0.5.4

View File

@ -2270,7 +2270,7 @@ zigpy-zigate==0.12.0
zigpy-znp==0.12.1 zigpy-znp==0.12.1
# homeassistant.components.zha # homeassistant.components.zha
zigpy==0.62.3 zigpy==0.63.2
# homeassistant.components.zwave_js # homeassistant.components.zwave_js
zwave-js-server-python==0.55.3 zwave-js-server-python==0.55.3

View File

@ -156,12 +156,7 @@ async def zigpy_app_controller():
zigpy.config.CONF_NWK_BACKUP_ENABLED: False, zigpy.config.CONF_NWK_BACKUP_ENABLED: False,
zigpy.config.CONF_TOPO_SCAN_ENABLED: False, zigpy.config.CONF_TOPO_SCAN_ENABLED: False,
zigpy.config.CONF_OTA: { zigpy.config.CONF_OTA: {
zigpy.config.CONF_OTA_IKEA: False, zigpy.config.CONF_OTA_ENABLED: False,
zigpy.config.CONF_OTA_INOVELLI: False,
zigpy.config.CONF_OTA_LEDVANCE: False,
zigpy.config.CONF_OTA_SALUS: False,
zigpy.config.CONF_OTA_SONOFF: False,
zigpy.config.CONF_OTA_THIRDREALITY: False,
}, },
} }
) )

View File

@ -24,7 +24,8 @@ from homeassistant.components.zha.core.helpers import get_zha_gateway
import homeassistant.components.zha.core.registries as zha_regs import homeassistant.components.zha.core.registries as zha_regs
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
import homeassistant.helpers.entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_platform import EntityPlatform
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_PROFILE, SIG_EP_TYPE
from .zha_devices_list import ( from .zha_devices_list import (
@ -80,8 +81,6 @@ async def test_devices(
zha_device_joined_restored, zha_device_joined_restored,
) -> None: ) -> None:
"""Test device discovery.""" """Test device discovery."""
entity_registry = er.async_get(hass_disable_services)
zigpy_device = zigpy_device_mock( zigpy_device = zigpy_device_mock(
endpoints=device[SIG_ENDPOINTS], endpoints=device[SIG_ENDPOINTS],
ieee="00:11:22:33:44:55:66:77", ieee="00:11:22:33:44:55:66:77",
@ -96,14 +95,13 @@ async def test_devices(
if cluster_identify: if cluster_identify:
cluster_identify.request.reset_mock() cluster_identify.request.reset_mock()
orig_new_entity = Endpoint.async_new_entity with patch(
_dispatch = mock.MagicMock(wraps=orig_new_entity) "homeassistant.helpers.entity_platform.EntityPlatform._async_schedule_add_entities_for_entry",
try: side_effect=EntityPlatform._async_schedule_add_entities_for_entry,
Endpoint.async_new_entity = lambda *a, **kw: _dispatch(*a, **kw) autospec=True,
) as mock_add_entities:
zha_dev = await zha_device_joined_restored(zigpy_device) zha_dev = await zha_device_joined_restored(zigpy_device)
await hass_disable_services.async_block_till_done() await hass_disable_services.async_block_till_done()
finally:
Endpoint.async_new_entity = orig_new_entity
if cluster_identify: if cluster_identify:
# We only identify on join # We only identify on join
@ -136,60 +134,38 @@ async def test_devices(
for ch in endpoint.client_cluster_handlers.values() for ch in endpoint.client_cluster_handlers.values()
} }
assert event_cluster_handlers == set(device[DEV_SIG_EVT_CLUSTER_HANDLERS]) assert event_cluster_handlers == set(device[DEV_SIG_EVT_CLUSTER_HANDLERS])
# we need to probe the class create entity factory so we need to reset this to get accurate results
zha_regs.ZHA_ENTITIES.clean_up()
# build a dict of entity_class -> (platform, unique_id, cluster_handlers) tuple
ha_ent_info = {}
created_entity_count = 0
for call in _dispatch.call_args_list:
_, platform, entity_cls, unique_id, cluster_handlers = call[0]
# the factory can return None. We filter these out to get an accurate created entity count
response = entity_cls.create_entity(unique_id, zha_dev, cluster_handlers)
if response and not contains_ignored_suffix(response.unique_id):
created_entity_count += 1
unique_id_head = UNIQUE_ID_HD.match(unique_id).group(
0
) # ieee + endpoint_id
ha_ent_info[(unique_id_head, entity_cls.__name__)] = (
platform,
unique_id,
cluster_handlers,
)
for comp_id, ent_info in device[DEV_SIG_ENT_MAP].items(): # Keep track of unhandled entities: they should always be ones we explicitly ignore
platform, unique_id = comp_id created_entities = {
entity.entity_id: entity
for mock_call in mock_add_entities.mock_calls
for entity in mock_call.args[1]
}
unhandled_entities = set(created_entities.keys())
entity_registry = er.async_get(hass_disable_services)
for (platform, unique_id), ent_info in device[DEV_SIG_ENT_MAP].items():
no_tail_id = NO_TAIL_ID.sub("", ent_info[DEV_SIG_ENT_MAP_ID]) no_tail_id = NO_TAIL_ID.sub("", ent_info[DEV_SIG_ENT_MAP_ID])
ha_entity_id = entity_registry.async_get_entity_id(platform, "zha", unique_id) ha_entity_id = entity_registry.async_get_entity_id(platform, "zha", unique_id)
assert ha_entity_id is not None assert ha_entity_id is not None
assert ha_entity_id.startswith(no_tail_id) assert ha_entity_id.startswith(no_tail_id)
test_ent_class = ent_info[DEV_SIG_ENT_MAP_CLASS] entity = created_entities[ha_entity_id]
test_unique_id_head = UNIQUE_ID_HD.match(unique_id).group(0) unhandled_entities.remove(ha_entity_id)
assert (test_unique_id_head, test_ent_class) in ha_ent_info
ha_comp, ha_unique_id, ha_cluster_handlers = ha_ent_info[ assert entity.platform.domain == platform
(test_unique_id_head, test_ent_class) assert type(entity).__name__ == ent_info[DEV_SIG_ENT_MAP_CLASS]
]
assert platform is ha_comp.value
# unique_id used for discover is the same for "multi entities" # unique_id used for discover is the same for "multi entities"
assert unique_id.startswith(ha_unique_id) assert unique_id == entity.unique_id
assert {ch.name for ch in ha_cluster_handlers} == set( assert {ch.name for ch in entity.cluster_handlers.values()} == set(
ent_info[DEV_SIG_CLUSTER_HANDLERS] ent_info[DEV_SIG_CLUSTER_HANDLERS]
) )
assert created_entity_count == len(device[DEV_SIG_ENT_MAP]) # All unhandled entities should be ones we explicitly ignore
for entity_id in unhandled_entities:
entity_ids = hass_disable_services.states.async_entity_ids() domain = entity_id.split(".")[0]
await hass_disable_services.async_block_till_done() assert domain in zha_const.PLATFORMS
assert contains_ignored_suffix(entity_id)
zha_entity_ids = {
ent
for ent in entity_ids
if not contains_ignored_suffix(ent) and ent.split(".")[0] in zha_const.PLATFORMS
}
assert zha_entity_ids == {
e[DEV_SIG_ENT_MAP_ID] for e in device[DEV_SIG_ENT_MAP].values()
}
def _get_first_identify_cluster(zigpy_device): def _get_first_identify_cluster(zigpy_device):

View File

@ -1,10 +1,11 @@
"""Test ZHA firmware updates.""" """Test ZHA firmware updates."""
from unittest.mock import AsyncMock, MagicMock, call, patch from unittest.mock import AsyncMock, call, patch
import pytest import pytest
from zigpy.exceptions import DeliveryError from zigpy.exceptions import DeliveryError
from zigpy.ota import CachedImage from zigpy.ota import OtaImageWithMetadata
import zigpy.ota.image as firmware import zigpy.ota.image as firmware
from zigpy.ota.providers import BaseOtaImageMetadata
import zigpy.profiles.zha as zha import zigpy.profiles.zha as zha
import zigpy.types as t import zigpy.types as t
import zigpy.zcl.clusters.general as general import zigpy.zcl.clusters.general as general
@ -21,17 +22,14 @@ from homeassistant.components.update import (
DOMAIN as UPDATE_DOMAIN, DOMAIN as UPDATE_DOMAIN,
SERVICE_INSTALL, SERVICE_INSTALL,
) )
from homeassistant.components.update.const import ATTR_SKIPPED_VERSION
from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON, Platform from homeassistant.const import ATTR_ENTITY_ID, STATE_OFF, STATE_ON, Platform
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import async_enable_traffic, find_entity_id, update_attribute_cache from .common import async_enable_traffic, find_entity_id, update_attribute_cache
from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE from .conftest import SIG_EP_INPUT, SIG_EP_OUTPUT, SIG_EP_TYPE
from tests.common import mock_restore_cache_with_extra_data
@pytest.fixture(autouse=True) @pytest.fixture(autouse=True)
def update_platform_only(): def update_platform_only():
@ -80,26 +78,32 @@ async def setup_test_data(
update_attribute_cache(cluster) update_attribute_cache(cluster)
# set up firmware image # set up firmware image
fw_image = firmware.OTAImage() fw_image = OtaImageWithMetadata(
fw_image.subelements = [firmware.SubElement(tag_id=0x0000, data=b"fw_image")] metadata=BaseOtaImageMetadata(
fw_header = firmware.OTAImageHeader( file_version=fw_version,
file_version=fw_version, manufacturer_id=0x1234,
image_type=0x90, image_type=0x90,
manufacturer_id=zigpy_device.manufacturer_id, changelog="This is a test firmware image!",
upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE, ),
header_version=256, firmware=firmware.OTAImage(
header_length=56, header=firmware.OTAImageHeader(
field_control=0, upgrade_file_id=firmware.OTAImageHeader.MAGIC_VALUE,
stack_version=2, file_version=fw_version,
header_string="This is a test header!", image_type=0x90,
image_size=56 + 2 + 4 + 8, manufacturer_id=0x1234,
header_version=256,
header_length=56,
field_control=0,
stack_version=2,
header_string="This is a test header!",
image_size=56 + 2 + 4 + 8,
),
subelements=[firmware.SubElement(tag_id=0x0000, data=b"fw_image")],
),
) )
fw_image.header = fw_header
fw_image.should_update = MagicMock(return_value=True)
cached_image = CachedImage(fw_image)
cluster.endpoint.device.application.ota.get_ota_image = AsyncMock( cluster.endpoint.device.application.ota.get_ota_image = AsyncMock(
return_value=None if file_not_found else cached_image return_value=None if file_not_found else fw_image
) )
zha_device = await zha_device_joined_restored(zigpy_device) zha_device = await zha_device_joined_restored(zigpy_device)
@ -108,18 +112,15 @@ async def setup_test_data(
return zha_device, cluster, fw_image, installed_fw_version return zha_device, cluster, fw_image, installed_fw_version
@pytest.mark.parametrize("initial_version_unknown", (False, True))
async def test_firmware_update_notification_from_zigpy( async def test_firmware_update_notification_from_zigpy(
hass: HomeAssistant, hass: HomeAssistant,
zha_device_joined_restored, zha_device_joined_restored,
zigpy_device, zigpy_device,
initial_version_unknown,
) -> None: ) -> None:
"""Test ZHA update platform - firmware update notification.""" """Test ZHA update platform - firmware update notification."""
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored, zha_device_joined_restored,
zigpy_device, zigpy_device,
skip_attribute_plugs=initial_version_unknown,
) )
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
@ -132,12 +133,16 @@ async def test_firmware_update_notification_from_zigpy(
# simulate an image available notification # simulate an image available notification
await cluster._handle_query_next_image( await cluster._handle_query_next_image(
fw_image.header.field_control, foundation.ZCLHeader.cluster(
zha_device.manufacturer_code, tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
fw_image.header.image_type, ),
installed_fw_version, general.QueryNextImageCommand(
fw_image.header.header_version, fw_image.firmware.header.field_control,
tsn=15, zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -146,7 +151,9 @@ async def test_firmware_update_notification_from_zigpy(
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def test_firmware_update_notification_from_service_call( async def test_firmware_update_notification_from_service_call(
@ -167,35 +174,46 @@ async def test_firmware_update_notification_from_service_call(
async def _async_image_notify_side_effect(*args, **kwargs): async def _async_image_notify_side_effect(*args, **kwargs):
await cluster._handle_query_next_image( await cluster._handle_query_next_image(
fw_image.header.field_control, foundation.ZCLHeader.cluster(
zha_device.manufacturer_code, tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
fw_image.header.image_type, ),
installed_fw_version, general.QueryNextImageCommand(
fw_image.header.header_version, fw_image.firmware.header.field_control,
tsn=15, zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
) )
await async_setup_component(hass, HA_DOMAIN, {}) await async_setup_component(hass, HA_DOMAIN, {})
cluster.image_notify = AsyncMock(side_effect=_async_image_notify_side_effect) with patch(
await hass.services.async_call( "zigpy.ota.OTA.broadcast_notify", side_effect=_async_image_notify_side_effect
HA_DOMAIN, ):
SERVICE_UPDATE_ENTITY, await hass.services.async_call(
service_data={ATTR_ENTITY_ID: entity_id}, HA_DOMAIN,
blocking=True, SERVICE_UPDATE_ENTITY,
) service_data={ATTR_ENTITY_ID: entity_id},
assert cluster.image_notify.await_count == 1 blocking=True,
assert cluster.image_notify.call_args_list[0] == call( )
payload_type=cluster.ImageNotifyCommand.PayloadType.QueryJitter,
query_jitter=100,
)
await hass.async_block_till_done() assert cluster.endpoint.device.application.ota.broadcast_notify.await_count == 1
state = hass.states.get(entity_id) assert cluster.endpoint.device.application.ota.broadcast_notify.call_args_list[
assert state.state == STATE_ON 0
attrs = state.attributes ] == call(
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" jitter=100,
assert not attrs[ATTR_IN_PROGRESS] )
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}"
await hass.async_block_till_done()
state = hass.states.get(entity_id)
assert state.state == STATE_ON
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert (
attrs[ATTR_LATEST_VERSION]
== f"0x{fw_image.firmware.header.file_version:08x}"
)
def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs): def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs):
@ -226,6 +244,7 @@ def make_packet(zigpy_device, cluster, cmd_name: str, **kwargs):
return ota_packet return ota_packet
@patch("zigpy.device.AFTER_OTA_ATTR_READ_DELAY", 0.01)
async def test_firmware_update_success( async def test_firmware_update_success(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None: ) -> None:
@ -234,6 +253,8 @@ async def test_firmware_update_success(
zha_device_joined_restored, zigpy_device zha_device_joined_restored, zigpy_device
) )
assert installed_fw_version < fw_image.firmware.header.file_version
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
assert entity_id is not None assert entity_id is not None
@ -244,12 +265,15 @@ async def test_firmware_update_success(
# simulate an image available notification # simulate an image available notification
await cluster._handle_query_next_image( await cluster._handle_query_next_image(
fw_image.header.field_control, foundation.ZCLHeader.cluster(
zha_device.manufacturer_code, tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
fw_image.header.image_type, ),
installed_fw_version, general.QueryNextImageCommand(
fw_image.header.header_version, field_control=fw_image.firmware.header.field_control,
tsn=15, manufacturer_code=zha_device.manufacturer_code,
image_type=fw_image.firmware.header.image_type,
current_file_version=installed_fw_version,
),
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -258,7 +282,9 @@ async def test_firmware_update_success(
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def endpoint_reply(cluster_id, tsn, data, command_id): async def endpoint_reply(cluster_id, tsn, data, command_id):
if cluster_id == general.Ota.cluster_id: if cluster_id == general.Ota.cluster_id:
@ -270,9 +296,9 @@ async def test_firmware_update_success(
cluster, cluster,
general.Ota.ServerCommandDefs.query_next_image.name, general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion, field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.header.manufacturer_id, manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.header.image_type, image_type=fw_image.firmware.header.image_type,
current_file_version=fw_image.header.file_version - 10, current_file_version=fw_image.firmware.header.file_version - 10,
hardware_version=1, hardware_version=1,
) )
) )
@ -280,19 +306,19 @@ async def test_firmware_update_success(
cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema
): ):
assert cmd.status == foundation.Status.SUCCESS assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.header.file_version assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.image_size == fw_image.header.image_size assert cmd.image_size == fw_image.firmware.header.image_size
zigpy_device.packet_received( zigpy_device.packet_received(
make_packet( make_packet(
zigpy_device, zigpy_device,
cluster, cluster,
general.Ota.ServerCommandDefs.image_block.name, general.Ota.ServerCommandDefs.image_block.name,
field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr, field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr,
manufacturer_code=fw_image.header.manufacturer_id, manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.header.image_type, image_type=fw_image.firmware.header.image_type,
file_version=fw_image.header.file_version, file_version=fw_image.firmware.header.file_version,
file_offset=0, file_offset=0,
maximum_data_size=40, maximum_data_size=40,
request_node_addr=zigpy_device.ieee, request_node_addr=zigpy_device.ieee,
@ -303,20 +329,23 @@ async def test_firmware_update_success(
): ):
if cmd.file_offset == 0: if cmd.file_offset == 0:
assert cmd.status == foundation.Status.SUCCESS assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id assert (
assert cmd.image_type == fw_image.header.image_type cmd.manufacturer_code
assert cmd.file_version == fw_image.header.file_version == fw_image.firmware.header.manufacturer_id
)
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.file_offset == 0 assert cmd.file_offset == 0
assert cmd.image_data == fw_image.serialize()[0:40] assert cmd.image_data == fw_image.firmware.serialize()[0:40]
zigpy_device.packet_received( zigpy_device.packet_received(
make_packet( make_packet(
zigpy_device, zigpy_device,
cluster, cluster,
general.Ota.ServerCommandDefs.image_block.name, general.Ota.ServerCommandDefs.image_block.name,
field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr, field_control=general.Ota.ImageBlockCommand.FieldControl.RequestNodeAddr,
manufacturer_code=fw_image.header.manufacturer_id, manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.header.image_type, image_type=fw_image.firmware.header.image_type,
file_version=fw_image.header.file_version, file_version=fw_image.firmware.header.file_version,
file_offset=40, file_offset=40,
maximum_data_size=40, maximum_data_size=40,
request_node_addr=zigpy_device.ieee, request_node_addr=zigpy_device.ieee,
@ -324,11 +353,14 @@ async def test_firmware_update_success(
) )
elif cmd.file_offset == 40: elif cmd.file_offset == 40:
assert cmd.status == foundation.Status.SUCCESS assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id assert (
assert cmd.image_type == fw_image.header.image_type cmd.manufacturer_code
assert cmd.file_version == fw_image.header.file_version == fw_image.firmware.header.manufacturer_id
)
assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.file_offset == 40 assert cmd.file_offset == 40
assert cmd.image_data == fw_image.serialize()[40:70] assert cmd.image_data == fw_image.firmware.serialize()[40:70]
# make sure the state machine gets progress reports # make sure the state machine gets progress reports
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
@ -337,10 +369,10 @@ async def test_firmware_update_success(
assert ( assert (
attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
) )
assert attrs[ATTR_IN_PROGRESS] == 57 assert attrs[ATTR_IN_PROGRESS] == 58
assert ( assert (
attrs[ATTR_LATEST_VERSION] attrs[ATTR_LATEST_VERSION]
== f"0x{fw_image.header.file_version:08x}" == f"0x{fw_image.firmware.header.file_version:08x}"
) )
zigpy_device.packet_received( zigpy_device.packet_received(
@ -349,21 +381,34 @@ async def test_firmware_update_success(
cluster, cluster,
general.Ota.ServerCommandDefs.upgrade_end.name, general.Ota.ServerCommandDefs.upgrade_end.name,
status=foundation.Status.SUCCESS, status=foundation.Status.SUCCESS,
manufacturer_code=fw_image.header.manufacturer_id, manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.header.image_type, image_type=fw_image.firmware.header.image_type,
file_version=fw_image.header.file_version, file_version=fw_image.firmware.header.file_version,
) )
) )
elif isinstance( elif isinstance(
cmd, general.Ota.ClientCommandDefs.upgrade_end_response.schema cmd, general.Ota.ClientCommandDefs.upgrade_end_response.schema
): ):
assert cmd.manufacturer_code == fw_image.header.manufacturer_id assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.header.file_version assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.current_time == 0 assert cmd.current_time == 0
assert cmd.upgrade_time == 0 assert cmd.upgrade_time == 0
def read_new_fw_version(*args, **kwargs):
cluster.update_attribute(
attrid=general.Ota.AttributeDefs.current_file_version.id,
value=fw_image.firmware.header.file_version,
)
return {
general.Ota.AttributeDefs.current_file_version.id: (
fw_image.firmware.header.file_version
)
}, {}
cluster.read_attributes.side_effect = read_new_fw_version
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply) cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
await hass.services.async_call( await hass.services.async_call(
UPDATE_DOMAIN, UPDATE_DOMAIN,
@ -377,10 +422,21 @@ async def test_firmware_update_success(
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
assert state.state == STATE_OFF assert state.state == STATE_OFF
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{fw_image.header.file_version:08x}" assert (
attrs[ATTR_INSTALLED_VERSION]
== f"0x{fw_image.firmware.header.file_version:08x}"
)
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == attrs[ATTR_INSTALLED_VERSION] assert attrs[ATTR_LATEST_VERSION] == attrs[ATTR_INSTALLED_VERSION]
# If we send a progress notification incorrectly, it won't be handled
entity = hass.data[UPDATE_DOMAIN].get_entity(entity_id)
entity._update_progress(50, 100, 0.50)
state = hass.states.get(entity_id)
assert not attrs[ATTR_IN_PROGRESS]
assert state.state == STATE_OFF
async def test_firmware_update_raises( async def test_firmware_update_raises(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device hass: HomeAssistant, zha_device_joined_restored, zigpy_device
@ -400,12 +456,16 @@ async def test_firmware_update_raises(
# simulate an image available notification # simulate an image available notification
await cluster._handle_query_next_image( await cluster._handle_query_next_image(
fw_image.header.field_control, foundation.ZCLHeader.cluster(
zha_device.manufacturer_code, tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
fw_image.header.image_type, ),
installed_fw_version, general.QueryNextImageCommand(
fw_image.header.header_version, fw_image.firmware.header.field_control,
tsn=15, zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
) )
await hass.async_block_till_done() await hass.async_block_till_done()
@ -414,7 +474,9 @@ async def test_firmware_update_raises(
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
)
async def endpoint_reply(cluster_id, tsn, data, command_id): async def endpoint_reply(cluster_id, tsn, data, command_id):
if cluster_id == general.Ota.cluster_id: if cluster_id == general.Ota.cluster_id:
@ -426,9 +488,9 @@ async def test_firmware_update_raises(
cluster, cluster,
general.Ota.ServerCommandDefs.query_next_image.name, general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion, field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.header.manufacturer_id, manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.header.image_type, image_type=fw_image.firmware.header.image_type,
current_file_version=fw_image.header.file_version - 10, current_file_version=fw_image.firmware.header.file_version - 10,
hardware_version=1, hardware_version=1,
) )
) )
@ -436,10 +498,10 @@ async def test_firmware_update_raises(
cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema cmd, general.Ota.ClientCommandDefs.query_next_image_response.schema
): ):
assert cmd.status == foundation.Status.SUCCESS assert cmd.status == foundation.Status.SUCCESS
assert cmd.manufacturer_code == fw_image.header.manufacturer_id assert cmd.manufacturer_code == fw_image.firmware.header.manufacturer_id
assert cmd.image_type == fw_image.header.image_type assert cmd.image_type == fw_image.firmware.header.image_type
assert cmd.file_version == fw_image.header.file_version assert cmd.file_version == fw_image.firmware.header.file_version
assert cmd.image_size == fw_image.header.image_size assert cmd.image_size == fw_image.firmware.header.image_size
raise DeliveryError("failed to deliver") raise DeliveryError("failed to deliver")
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply) cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
@ -467,29 +529,10 @@ async def test_firmware_update_raises(
) )
async def test_firmware_update_restore_data( async def test_firmware_update_no_longer_compatible(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None: ) -> None:
"""Test ZHA update platform - restore data.""" """Test ZHA update platform - firmware update is no longer valid."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data( zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored, zigpy_device zha_device_joined_restored, zigpy_device
) )
@ -500,94 +543,67 @@ async def test_firmware_update_restore_data(
# allow traffic to flow through the gateway and device # allow traffic to flow through the gateway and device
await async_enable_traffic(hass, [zha_device]) await async_enable_traffic(hass, [zha_device])
assert hass.states.get(entity_id).state == STATE_OFF
# simulate an image available notification
await cluster._handle_query_next_image(
foundation.ZCLHeader.cluster(
tsn=0x12, command_id=general.Ota.ServerCommandDefs.query_next_image.id
),
general.QueryNextImageCommand(
fw_image.firmware.header.field_control,
zha_device.manufacturer_code,
fw_image.firmware.header.image_type,
installed_fw_version,
fw_image.firmware.header.header_version,
),
)
await hass.async_block_till_done()
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
assert state.state == STATE_ON assert state.state == STATE_ON
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.header.file_version:08x}" assert (
attrs[ATTR_LATEST_VERSION] == f"0x{fw_image.firmware.header.file_version:08x}"
async def test_firmware_update_restore_file_not_found(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
"""Test ZHA update platform - restore data - file not found."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored, zigpy_device, file_not_found=True
) )
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass) new_version = 0x99999999
assert entity_id is not None
# allow traffic to flow through the gateway and device async def endpoint_reply(cluster_id, tsn, data, command_id):
await async_enable_traffic(hass, [zha_device]) if cluster_id == general.Ota.cluster_id:
hdr, cmd = cluster.deserialize(data)
if isinstance(cmd, general.Ota.ImageNotifyCommand):
zigpy_device.packet_received(
make_packet(
zigpy_device,
cluster,
general.Ota.ServerCommandDefs.query_next_image.name,
field_control=general.Ota.QueryNextImageCommand.FieldControl.HardwareVersion,
manufacturer_code=fw_image.firmware.header.manufacturer_id,
image_type=fw_image.firmware.header.image_type,
# The device reports that it is no longer compatible!
current_file_version=new_version,
hardware_version=1,
)
)
cluster.endpoint.reply = AsyncMock(side_effect=endpoint_reply)
with pytest.raises(HomeAssistantError):
await hass.services.async_call(
UPDATE_DOMAIN,
SERVICE_INSTALL,
{
ATTR_ENTITY_ID: entity_id,
},
blocking=True,
)
# We updated the currently installed firmware version, as it is no longer valid
state = hass.states.get(entity_id) state = hass.states.get(entity_id)
assert state.state == STATE_OFF assert state.state == STATE_OFF
attrs = state.attributes attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_INSTALLED_VERSION] == f"0x{new_version:08x}"
assert not attrs[ATTR_IN_PROGRESS] assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}" assert attrs[ATTR_LATEST_VERSION] == f"0x{new_version:08x}"
async def test_firmware_update_restore_version_from_state_machine(
hass: HomeAssistant, zha_device_joined_restored, zigpy_device
) -> None:
"""Test ZHA update platform - restore data - file not found."""
fw_version = 0x12345678
installed_fw_version = fw_version - 10
mock_restore_cache_with_extra_data(
hass,
[
(
State(
"update.fakemanufacturer_fakemodel_firmware",
STATE_ON,
{
ATTR_INSTALLED_VERSION: f"0x{installed_fw_version:08x}",
ATTR_LATEST_VERSION: f"0x{fw_version:08x}",
ATTR_SKIPPED_VERSION: None,
},
),
{"image_type": 0x90},
)
],
)
zha_device, cluster, fw_image, installed_fw_version = await setup_test_data(
zha_device_joined_restored,
zigpy_device,
skip_attribute_plugs=True,
file_not_found=True,
)
entity_id = find_entity_id(Platform.UPDATE, zha_device, hass)
assert entity_id is not None
# allow traffic to flow through the gateway and device
await async_enable_traffic(hass, [zha_device])
state = hass.states.get(entity_id)
assert state.state == STATE_OFF
attrs = state.attributes
assert attrs[ATTR_INSTALLED_VERSION] == f"0x{installed_fw_version:08x}"
assert not attrs[ATTR_IN_PROGRESS]
assert attrs[ATTR_LATEST_VERSION] == f"0x{installed_fw_version:08x}"