Update tplink unlink identifiers to deal with ids from other domains (#120596)

This commit is contained in:
Steven B 2024-06-27 13:54:34 +01:00 committed by GitHub
parent 970dd99226
commit 9758b08036
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 123 additions and 58 deletions

View File

@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
from collections.abc import Iterable
from datetime import timedelta from datetime import timedelta
import logging import logging
from typing import Any from typing import Any
@ -282,6 +283,28 @@ def mac_alias(mac: str) -> str:
return mac.replace(":", "")[-4:].upper() return mac.replace(":", "")[-4:].upper()
def _mac_connection_or_none(device: dr.DeviceEntry) -> str | None:
return next(
(
conn
for type_, conn in device.connections
if type_ == dr.CONNECTION_NETWORK_MAC
),
None,
)
def _device_id_is_mac_or_none(mac: str, device_ids: Iterable[str]) -> str | None:
# Previously only iot devices had child devices and iot devices use
# the upper and lcase MAC addresses as device_id so match on case
# insensitive mac address as the parent device.
upper_mac = mac.upper()
return next(
(device_id for device_id in device_ids if device_id.upper() == upper_mac),
None,
)
async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Migrate old entry.""" """Migrate old entry."""
version = config_entry.version version = config_entry.version
@ -298,49 +321,48 @@ async def async_migrate_entry(hass: HomeAssistant, config_entry: ConfigEntry) ->
# always be linked into one device. # always be linked into one device.
dev_reg = dr.async_get(hass) dev_reg = dr.async_get(hass)
for device in dr.async_entries_for_config_entry(dev_reg, config_entry.entry_id): for device in dr.async_entries_for_config_entry(dev_reg, config_entry.entry_id):
new_identifiers: set[tuple[str, str]] | None = None original_identifiers = device.identifiers
if len(device.identifiers) > 1 and ( # Get only the tplink identifier, could be tapo or other integrations.
mac := next( tplink_identifiers = [
iter( ident[1] for ident in original_identifiers if ident[0] == DOMAIN
[ ]
conn[1] # Nothing to fix if there's only one identifier. mac connection
for conn in device.connections # should never be none but if it is there's no problem.
if conn[0] == dr.CONNECTION_NETWORK_MAC if len(tplink_identifiers) <= 1 or not (
] mac := _mac_connection_or_none(device)
), ):
None, continue
if not (
tplink_parent_device_id := _device_id_is_mac_or_none(
mac, tplink_identifiers
) )
): ):
for identifier in device.identifiers: # No match on mac so raise an error.
# Previously only iot devices that use the MAC address as _LOGGER.error(
# device_id had child devices so check for mac as the "Unable to replace identifiers for device %s (%s): %s",
# parent device. device.name,
if identifier[0] == DOMAIN and identifier[1].upper() == mac.upper(): device.model,
new_identifiers = {identifier} device.identifiers,
break )
if new_identifiers: continue
dev_reg.async_update_device( # Retain any identifiers for other domains
device.id, new_identifiers=new_identifiers new_identifiers = {
) ident for ident in device.identifiers if ident[0] != DOMAIN
_LOGGER.debug( }
"Replaced identifiers for device %s (%s): %s with: %s", new_identifiers.add((DOMAIN, tplink_parent_device_id))
device.name, dev_reg.async_update_device(device.id, new_identifiers=new_identifiers)
device.model, _LOGGER.debug(
device.identifiers, "Replaced identifiers for device %s (%s): %s with: %s",
new_identifiers, device.name,
) device.model,
else: original_identifiers,
# No match on mac so raise an error. new_identifiers,
_LOGGER.error( )
"Unable to replace identifiers for device %s (%s): %s",
device.name,
device.model,
device.identifiers,
)
minor_version = 3 minor_version = 3
hass.config_entries.async_update_entry(config_entry, minor_version=3) hass.config_entries.async_update_entry(config_entry, minor_version=3)
_LOGGER.debug("Migration to version %s.%s successful", version, minor_version)
_LOGGER.debug("Migration to version %s.%s complete", version, minor_version)
if version == 1 and minor_version == 3: if version == 1 and minor_version == 3:
# credentials_hash stored in the device_config should be moved to data. # credentials_hash stored in the device_config should be moved to data.

View File

@ -49,6 +49,7 @@ ALIAS = "My Bulb"
MODEL = "HS100" MODEL = "HS100"
MAC_ADDRESS = "aa:bb:cc:dd:ee:ff" MAC_ADDRESS = "aa:bb:cc:dd:ee:ff"
DEVICE_ID = "123456789ABCDEFGH" DEVICE_ID = "123456789ABCDEFGH"
DEVICE_ID_MAC = "AA:BB:CC:DD:EE:FF"
DHCP_FORMATTED_MAC_ADDRESS = MAC_ADDRESS.replace(":", "") DHCP_FORMATTED_MAC_ADDRESS = MAC_ADDRESS.replace(":", "")
MAC_ADDRESS2 = "11:22:33:44:55:66" MAC_ADDRESS2 = "11:22:33:44:55:66"
DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}" DEFAULT_ENTRY_TITLE = f"{ALIAS} {MODEL}"

View File

@ -36,6 +36,8 @@ from . import (
CREATE_ENTRY_DATA_AUTH, CREATE_ENTRY_DATA_AUTH,
CREATE_ENTRY_DATA_LEGACY, CREATE_ENTRY_DATA_LEGACY,
DEVICE_CONFIG_AUTH, DEVICE_CONFIG_AUTH,
DEVICE_ID,
DEVICE_ID_MAC,
IP_ADDRESS, IP_ADDRESS,
MAC_ADDRESS, MAC_ADDRESS,
_mocked_device, _mocked_device,
@ -404,19 +406,48 @@ async def test_feature_no_category(
@pytest.mark.parametrize( @pytest.mark.parametrize(
("identifier_base", "expected_message", "expected_count"), ("device_id", "id_count", "domains", "expected_message"),
[ [
pytest.param("C0:06:C3:42:54:2B", "Replaced", 1, id="success"), pytest.param(DEVICE_ID_MAC, 1, [DOMAIN], None, id="mac-id-no-children"),
pytest.param("123456789", "Unable to replace", 3, id="failure"), pytest.param(DEVICE_ID_MAC, 3, [DOMAIN], "Replaced", id="mac-id-children"),
pytest.param(
DEVICE_ID_MAC,
1,
[DOMAIN, "other"],
None,
id="mac-id-no-children-other-domain",
),
pytest.param(
DEVICE_ID_MAC,
3,
[DOMAIN, "other"],
"Replaced",
id="mac-id-children-other-domain",
),
pytest.param(DEVICE_ID, 1, [DOMAIN], None, id="not-mac-id-no-children"),
pytest.param(
DEVICE_ID, 3, [DOMAIN], "Unable to replace", id="not-mac-children"
),
pytest.param(
DEVICE_ID, 1, [DOMAIN, "other"], None, id="not-mac-no-children-other-domain"
),
pytest.param(
DEVICE_ID,
3,
[DOMAIN, "other"],
"Unable to replace",
id="not-mac-children-other-domain",
),
], ],
) )
async def test_unlink_devices( async def test_unlink_devices(
hass: HomeAssistant, hass: HomeAssistant,
device_registry: dr.DeviceRegistry, device_registry: dr.DeviceRegistry,
caplog: pytest.LogCaptureFixture, caplog: pytest.LogCaptureFixture,
identifier_base, device_id,
id_count,
domains,
expected_message, expected_message,
expected_count,
) -> None: ) -> None:
"""Test for unlinking child device ids.""" """Test for unlinking child device ids."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -429,43 +460,54 @@ async def test_unlink_devices(
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
# Setup initial device registry, with linkages # Generate list of test identifiers
mac = "C0:06:C3:42:54:2B" test_identifiers = [
identifiers = [ (domain, f"{device_id}{"" if i == 0 else f"_000{i}"}")
(DOMAIN, identifier_base), for i in range(id_count)
(DOMAIN, f"{identifier_base}_0001"), for domain in domains
(DOMAIN, f"{identifier_base}_0002"),
] ]
update_msg_fragment = "identifiers for device dummy (hs300):"
update_msg = f"{expected_message} {update_msg_fragment}" if expected_message else ""
# Expected identifiers should include all other domains or all the newer non-mac device ids
# or just the parent mac device id
expected_identifiers = [
(domain, device_id)
for domain, device_id in test_identifiers
if domain != DOMAIN
or device_id.startswith(DEVICE_ID)
or device_id == DEVICE_ID_MAC
]
device_registry.async_get_or_create( device_registry.async_get_or_create(
config_entry_id="123456", config_entry_id="123456",
connections={ connections={
(dr.CONNECTION_NETWORK_MAC, mac.lower()), (dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS),
}, },
identifiers=set(identifiers), identifiers=set(test_identifiers),
model="hs300", model="hs300",
name="dummy", name="dummy",
) )
device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id)
assert device_entries[0].connections == { assert device_entries[0].connections == {
(dr.CONNECTION_NETWORK_MAC, mac.lower()), (dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS),
} }
assert device_entries[0].identifiers == set(identifiers) assert device_entries[0].identifiers == set(test_identifiers)
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id) device_entries = dr.async_entries_for_config_entry(device_registry, entry.entry_id)
assert device_entries[0].connections == {(dr.CONNECTION_NETWORK_MAC, mac.lower())} assert device_entries[0].connections == {(dr.CONNECTION_NETWORK_MAC, MAC_ADDRESS)}
# If expected count is 1 will be the first identifier only
expected_identifiers = identifiers[:expected_count]
assert device_entries[0].identifiers == set(expected_identifiers) assert device_entries[0].identifiers == set(expected_identifiers)
assert entry.version == 1 assert entry.version == 1
assert entry.minor_version == 4 assert entry.minor_version == 4
msg = f"{expected_message} identifiers for device dummy (hs300): {set(identifiers)}" assert update_msg in caplog.text
assert msg in caplog.text assert "Migration to version 1.3 complete" in caplog.text
async def test_move_credentials_hash( async def test_move_credentials_hash(