Fix invalid unique id for Transmission entities (#84664)

* Update unique id for Transmission entities

* Moved migration to a separate function

* Hopefully fixed coverage

* Extracted dictionary to constant

* review comments

* more comments

* revert accidental name change

* more review comments

* more review comments

* use lists instead of incorrect tuple syntax
This commit is contained in:
avee87 2023-06-28 09:45:13 +01:00 committed by GitHub
parent 2747da784c
commit a5b91cb7e3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 140 additions and 43 deletions

View File

@ -4,6 +4,7 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from functools import partial from functools import partial
import logging import logging
import re
from typing import Any from typing import Any
import transmission_rpc import transmission_rpc
@ -18,15 +19,20 @@ from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
CONF_HOST, CONF_HOST,
CONF_ID, CONF_ID,
CONF_NAME,
CONF_PASSWORD, CONF_PASSWORD,
CONF_PORT, CONF_PORT,
CONF_SCAN_INTERVAL, CONF_SCAN_INTERVAL,
CONF_USERNAME, CONF_USERNAME,
Platform, Platform,
) )
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import HomeAssistant, ServiceCall, callback
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import config_validation as cv, selector from homeassistant.helpers import (
config_validation as cv,
entity_registry as er,
selector,
)
from homeassistant.helpers.dispatcher import dispatcher_send from homeassistant.helpers.dispatcher import dispatcher_send
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
@ -91,9 +97,41 @@ CONFIG_SCHEMA = cv.removed(DOMAIN, raise_if_present=False)
PLATFORMS = [Platform.SENSOR, Platform.SWITCH] PLATFORMS = [Platform.SENSOR, Platform.SWITCH]
MIGRATION_NAME_TO_KEY = {
# Sensors
"Down Speed": "download",
"Up Speed": "upload",
"Status": "status",
"Active Torrents": "active_torrents",
"Paused Torrents": "paused_torrents",
"Total Torrents": "total_torrents",
"Completed Torrents": "completed_torrents",
"Started Torrents": "started_torrents",
# Switches
"Switch": "on_off",
"Turtle Mode": "turtle_mode",
}
async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, config_entry: ConfigEntry) -> bool:
"""Set up the Transmission Component.""" """Set up the Transmission Component."""
@callback
def update_unique_id(
entity_entry: er.RegistryEntry,
) -> dict[str, Any] | None:
"""Update unique ID of entity entry."""
match = re.search(
f"{config_entry.data[CONF_HOST]}-{config_entry.data[CONF_NAME]} (?P<name>.+)",
entity_entry.unique_id,
)
if match and (key := MIGRATION_NAME_TO_KEY.get(match.group("name"))):
return {"new_unique_id": f"{config_entry.entry_id}-{key}"}
return None
await er.async_migrate_entries(hass, config_entry.entry_id, update_unique_id)
client = TransmissionClient(hass, config_entry) client = TransmissionClient(hass, config_entry)
hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client hass.data.setdefault(DOMAIN, {})[config_entry.entry_id] = client

View File

@ -40,12 +40,20 @@ async def async_setup_entry(
dev = [ dev = [
TransmissionSpeedSensor(tm_client, name, "Down Speed", "download"), TransmissionSpeedSensor(tm_client, name, "Down Speed", "download"),
TransmissionSpeedSensor(tm_client, name, "Up Speed", "upload"), TransmissionSpeedSensor(tm_client, name, "Up Speed", "upload"),
TransmissionStatusSensor(tm_client, name, "Status"), TransmissionStatusSensor(tm_client, name, "Status", "status"),
TransmissionTorrentsSensor(tm_client, name, "Active Torrents", "active"), TransmissionTorrentsSensor(
TransmissionTorrentsSensor(tm_client, name, "Paused Torrents", "paused"), tm_client, name, "Active Torrents", "active_torrents"
TransmissionTorrentsSensor(tm_client, name, "Total Torrents", "total"), ),
TransmissionTorrentsSensor(tm_client, name, "Completed Torrents", "completed"), TransmissionTorrentsSensor(
TransmissionTorrentsSensor(tm_client, name, "Started Torrents", "started"), tm_client, name, "Paused Torrents", "paused_torrents"
),
TransmissionTorrentsSensor(tm_client, name, "Total Torrents", "total_torrents"),
TransmissionTorrentsSensor(
tm_client, name, "Completed Torrents", "completed_torrents"
),
TransmissionTorrentsSensor(
tm_client, name, "Started Torrents", "started_torrents"
),
] ]
async_add_entities(dev, True) async_add_entities(dev, True)
@ -56,13 +64,13 @@ class TransmissionSensor(SensorEntity):
_attr_should_poll = False _attr_should_poll = False
def __init__(self, tm_client, client_name, sensor_name, sub_type=None): def __init__(self, tm_client, client_name, sensor_name, key):
"""Initialize the sensor.""" """Initialize the sensor."""
self._tm_client: TransmissionClient = tm_client self._tm_client: TransmissionClient = tm_client
self._client_name = client_name self._attr_name = f"{client_name} {sensor_name}"
self._name = sensor_name self._key = key
self._sub_type = sub_type
self._state = None self._state = None
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{key}"
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, identifiers={(DOMAIN, tm_client.config_entry.entry_id)},
@ -70,16 +78,6 @@ class TransmissionSensor(SensorEntity):
name=client_name, name=client_name,
) )
@property
def name(self):
"""Return the name of the sensor."""
return f"{self._client_name} {self._name}"
@property
def unique_id(self):
"""Return the unique id of the entity."""
return f"{self._tm_client.api.host}-{self.name}"
@property @property
def native_value(self): def native_value(self):
"""Return the state of the sensor.""" """Return the state of the sensor."""
@ -118,7 +116,7 @@ class TransmissionSpeedSensor(TransmissionSensor):
if data := self._tm_client.api.data: if data := self._tm_client.api.data:
b_spd = ( b_spd = (
float(data.download_speed) float(data.download_speed)
if self._sub_type == "download" if self._key == "download"
else float(data.upload_speed) else float(data.upload_speed)
) )
self._state = b_spd self._state = b_spd
@ -151,12 +149,15 @@ class TransmissionStatusSensor(TransmissionSensor):
class TransmissionTorrentsSensor(TransmissionSensor): class TransmissionTorrentsSensor(TransmissionSensor):
"""Representation of a Transmission torrents sensor.""" """Representation of a Transmission torrents sensor."""
SUBTYPE_MODES = { MODES: dict[str, list[str] | None] = {
"started": ("downloading"), "started_torrents": ["downloading"],
"completed": ("seeding"), "completed_torrents": ["seeding"],
"paused": ("stopped"), "paused_torrents": ["stopped"],
"active": ("seeding", "downloading"), "active_torrents": [
"total": None, "seeding",
"downloading",
],
"total_torrents": None,
} }
@property @property
@ -171,7 +172,7 @@ class TransmissionTorrentsSensor(TransmissionSensor):
torrents=self._tm_client.api.torrents, torrents=self._tm_client.api.torrents,
order=self._tm_client.config_entry.options[CONF_ORDER], order=self._tm_client.config_entry.options[CONF_ORDER],
limit=self._tm_client.config_entry.options[CONF_LIMIT], limit=self._tm_client.config_entry.options[CONF_LIMIT],
statuses=self.SUBTYPE_MODES[self._sub_type], statuses=self.MODES[self._key],
) )
return { return {
STATE_ATTR_TORRENT_INFO: info, STATE_ATTR_TORRENT_INFO: info,
@ -180,7 +181,7 @@ class TransmissionTorrentsSensor(TransmissionSensor):
def update(self) -> None: def update(self) -> None:
"""Get the latest data from Transmission and updates the state.""" """Get the latest data from Transmission and updates the state."""
torrents = _filter_torrents( torrents = _filter_torrents(
self._tm_client.api.torrents, statuses=self.SUBTYPE_MODES[self._sub_type] self._tm_client.api.torrents, statuses=self.MODES[self._key]
) )
self._state = len(torrents) self._state = len(torrents)

View File

@ -40,13 +40,13 @@ class TransmissionSwitch(SwitchEntity):
def __init__(self, switch_type, switch_name, tm_client, client_name): def __init__(self, switch_type, switch_name, tm_client, client_name):
"""Initialize the Transmission switch.""" """Initialize the Transmission switch."""
self._name = switch_name self._attr_name = f"{client_name} {switch_name}"
self.client_name = client_name
self.type = switch_type self.type = switch_type
self._tm_client = tm_client self._tm_client = tm_client
self._state = STATE_OFF self._state = STATE_OFF
self._data = None self._data = None
self.unsub_update = None self.unsub_update = None
self._attr_unique_id = f"{tm_client.config_entry.entry_id}-{switch_type}"
self._attr_device_info = DeviceInfo( self._attr_device_info = DeviceInfo(
entry_type=DeviceEntryType.SERVICE, entry_type=DeviceEntryType.SERVICE,
identifiers={(DOMAIN, tm_client.config_entry.entry_id)}, identifiers={(DOMAIN, tm_client.config_entry.entry_id)},
@ -54,16 +54,6 @@ class TransmissionSwitch(SwitchEntity):
name=client_name, name=client_name,
) )
@property
def name(self):
"""Return the name of the switch."""
return f"{self.client_name} {self._name}"
@property
def unique_id(self):
"""Return the unique id of the entity."""
return f"{self._tm_client.api.host}-{self.name}"
@property @property
def is_on(self): def is_on(self):
"""Return true if device is on.""" """Return true if device is on."""

View File

@ -9,9 +9,12 @@ from transmission_rpc.error import (
TransmissionError, TransmissionError,
) )
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.components.switch import DOMAIN as SWITCH_DOMAIN
from homeassistant.components.transmission.const import DOMAIN from homeassistant.components.transmission.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from . import MOCK_CONFIG_DATA from . import MOCK_CONFIG_DATA
@ -91,3 +94,68 @@ async def test_unload_entry(hass: HomeAssistant) -> None:
assert entry.state is ConfigEntryState.NOT_LOADED assert entry.state is ConfigEntryState.NOT_LOADED
assert not hass.data[DOMAIN] assert not hass.data[DOMAIN]
@pytest.mark.parametrize(
("domain", "old_unique_id", "new_unique_id"),
[
(SENSOR_DOMAIN, "0.0.0.0-Transmission Down Speed", "1234-download"),
(SENSOR_DOMAIN, "0.0.0.0-Transmission Up Speed", "1234-upload"),
(SENSOR_DOMAIN, "0.0.0.0-Transmission Status", "1234-status"),
(
SENSOR_DOMAIN,
"0.0.0.0-Transmission Active Torrents",
"1234-active_torrents",
),
(
SENSOR_DOMAIN,
"0.0.0.0-Transmission Paused Torrents",
"1234-paused_torrents",
),
(SENSOR_DOMAIN, "0.0.0.0-Transmission Total Torrents", "1234-total_torrents"),
(
SENSOR_DOMAIN,
"0.0.0.0-Transmission Completed Torrents",
"1234-completed_torrents",
),
(
SENSOR_DOMAIN,
"0.0.0.0-Transmission Started Torrents",
"1234-started_torrents",
),
# no change on correct sensor unique id
(SENSOR_DOMAIN, "1234-started_torrents", "1234-started_torrents"),
(SWITCH_DOMAIN, "0.0.0.0-Transmission Switch", "1234-on_off"),
(SWITCH_DOMAIN, "0.0.0.0-Transmission Turtle Mode", "1234-turtle_mode"),
# no change on correct switch unique id
(SWITCH_DOMAIN, "1234-turtle_mode", "1234-turtle_mode"),
],
)
async def test_migrate_unique_id(
hass: HomeAssistant,
entity_registry: er.EntityRegistry,
domain: str,
old_unique_id: str,
new_unique_id: str,
) -> None:
"""Test unique id migration."""
entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG_DATA, entry_id="1234")
entry.add_to_hass(hass)
entity: er.RegistryEntry = entity_registry.async_get_or_create(
suggested_object_id=f"my_{domain}",
disabled_by=None,
domain=domain,
platform=DOMAIN,
unique_id=old_unique_id,
config_entry=entry,
)
assert entity.unique_id == old_unique_id
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
migrated_entity = entity_registry.async_get(entity.entity_id)
assert migrated_entity
assert migrated_entity.unique_id == new_unique_id