diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index 9e01a07416f..a510b5b7414 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -10,6 +10,7 @@ from opendata_transport.exceptions import ( from homeassistant import config_entries, core from homeassistant.const import Platform from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.helpers.aiohttp_client import async_get_clientsession from .const import CONF_DESTINATION, CONF_START, DOMAIN @@ -65,3 +66,51 @@ async def async_unload_entry( hass.data[DOMAIN].pop(entry.entry_id) return unload_ok + + +async def async_migrate_entry( + hass: core.HomeAssistant, config_entry: config_entries.ConfigEntry +) -> bool: + """Migrate config entry.""" + _LOGGER.debug("Migrating from version %s", config_entry.version) + + if config_entry.minor_version > 3: + # This means the user has downgraded from a future version + return False + + if config_entry.minor_version == 1: + # Remove wrongly registered devices and entries + new_unique_id = ( + f"{config_entry.data[CONF_START]} {config_entry.data[CONF_DESTINATION]}" + ) + entity_registry = er.async_get(hass) + device_registry = dr.async_get(hass) + device_entries = dr.async_entries_for_config_entry( + device_registry, config_entry_id=config_entry.entry_id + ) + for dev in device_entries: + device_registry.async_remove_device(dev.id) + + entity_id = entity_registry.async_get_entity_id( + Platform.SENSOR, DOMAIN, "None_departure" + ) + if entity_id: + entity_registry.async_update_entity( + entity_id=entity_id, + new_unique_id=f"{new_unique_id}_departure", + ) + _LOGGER.debug( + "Faulty entity with unique_id 'None_departure' migrated to new unique_id '%s'", + f"{new_unique_id}_departure", + ) + + # Set a valid unique id for config entries + config_entry.unique_id = new_unique_id + config_entry.minor_version = 2 + hass.config_entries.async_update_entry(config_entry) + + _LOGGER.debug( + "Migration to minor version %s successful", config_entry.minor_version + ) + + return True diff --git a/homeassistant/components/swiss_public_transport/config_flow.py b/homeassistant/components/swiss_public_transport/config_flow.py index 63eca1efe96..ceb6f46806d 100644 --- a/homeassistant/components/swiss_public_transport/config_flow.py +++ b/homeassistant/components/swiss_public_transport/config_flow.py @@ -31,6 +31,7 @@ class SwissPublicTransportConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Swiss public transport config flow.""" VERSION = 1 + MINOR_VERSION = 2 async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -59,6 +60,9 @@ class SwissPublicTransportConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): _LOGGER.exception("Unknown error") errors["base"] = "unknown" else: + await self.async_set_unique_id( + f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}" + ) return self.async_create_entry( title=f"{user_input[CONF_START]} {user_input[CONF_DESTINATION]}", data=user_input, @@ -98,6 +102,9 @@ class SwissPublicTransportConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ) return self.async_abort(reason="unknown") + await self.async_set_unique_id( + f"{import_input[CONF_START]} {import_input[CONF_DESTINATION]}" + ) return self.async_create_entry( title=import_input[CONF_NAME], data=import_input, diff --git a/tests/components/swiss_public_transport/test_init.py b/tests/components/swiss_public_transport/test_init.py new file mode 100644 index 00000000000..f2b4e41ed71 --- /dev/null +++ b/tests/components/swiss_public_transport/test_init.py @@ -0,0 +1,85 @@ +"""Test the swiss_public_transport config flow.""" +from unittest.mock import AsyncMock, patch + +from homeassistant.components.swiss_public_transport.const import ( + CONF_DESTINATION, + CONF_START, + DOMAIN, +) +from homeassistant.const import Platform +from homeassistant.core import HomeAssistant +from homeassistant.helpers import entity_registry as er + +from tests.common import MockConfigEntry + +MOCK_DATA_STEP = { + CONF_START: "test_start", + CONF_DESTINATION: "test_destination", +} + +CONNECTIONS = [ + { + "departure": "2024-01-06T18:03:00+0100", + "number": 0, + "platform": 0, + "transfers": 0, + "duration": "10", + "delay": 0, + }, + { + "departure": "2024-01-06T18:04:00+0100", + "number": 1, + "platform": 1, + "transfers": 0, + "duration": "10", + "delay": 0, + }, + { + "departure": "2024-01-06T18:05:00+0100", + "number": 2, + "platform": 2, + "transfers": 0, + "duration": "10", + "delay": 0, + }, +] + + +async def test_migration_1_to_2( + hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test successful setup.""" + + with patch( + "homeassistant.components.swiss_public_transport.OpendataTransport", + return_value=AsyncMock(), + ) as mock: + mock().connections = CONNECTIONS + + config_entry_faulty = MockConfigEntry( + domain=DOMAIN, + data=MOCK_DATA_STEP, + title="MIGRATION_TEST", + minor_version=1, + ) + config_entry_faulty.add_to_hass(hass) + + # Setup the config entry + await hass.config_entries.async_setup(config_entry_faulty.entry_id) + await hass.async_block_till_done() + assert entity_registry.async_is_registered( + entity_registry.entities.get_entity_id( + (Platform.SENSOR, DOMAIN, "test_start test_destination_departure") + ) + ) + + # Check change in config entry + assert config_entry_faulty.minor_version == 2 + assert config_entry_faulty.unique_id == "test_start test_destination" + + # Check "None" is gone + assert not entity_registry.async_is_registered( + entity_registry.entities.get_entity_id( + (Platform.SENSOR, DOMAIN, "None_departure") + ) + )