diff --git a/homeassistant/components/airvisual/__init__.py b/homeassistant/components/airvisual/__init__.py index c9d77226643..32c2d71292f 100644 --- a/homeassistant/components/airvisual/__init__.py +++ b/homeassistant/components/airvisual/__init__.py @@ -32,6 +32,7 @@ from homeassistant.helpers import ( aiohttp_client, config_validation as cv, device_registry as dr, + entity_registry as er, ) from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue @@ -116,36 +117,6 @@ def async_get_geography_id(geography_dict: Mapping[str, Any]) -> str: ) -@callback -def async_get_pro_config_entry_by_ip_address( - hass: HomeAssistant, ip_address: str -) -> ConfigEntry: - """Get the Pro config entry related to an IP address.""" - [config_entry] = [ - entry - for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO) - if entry.data[CONF_IP_ADDRESS] == ip_address - ] - return config_entry - - -@callback -def async_get_pro_device_by_config_entry( - hass: HomeAssistant, config_entry: ConfigEntry -) -> dr.DeviceEntry: - """Get the Pro device entry related to a config entry. - - Note that a Pro config entry can only contain a single device. - """ - device_registry = dr.async_get(hass) - [device_entry] = [ - device_entry - for device_entry in device_registry.devices.values() - if config_entry.entry_id in device_entry.config_entries - ] - return device_entry - - @callback def async_sync_geo_coordinator_update_intervals( hass: HomeAssistant, api_key: str @@ -305,14 +276,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: version = 3 if entry.data[CONF_INTEGRATION_TYPE] == INTEGRATION_TYPE_NODE_PRO: + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) ip_address = entry.data[CONF_IP_ADDRESS] - # Get the existing Pro device entry before it is removed by the migration: - old_device_entry = async_get_pro_device_by_config_entry(hass, entry) + # Store the existing Pro device before the migration removes it: + old_device_entry = next( + entry + for entry in dr.async_entries_for_config_entry( + device_registry, entry.entry_id + ) + ) + # Store the existing Pro entity entries (mapped by unique ID) before the + # migration removes it: + old_entity_entries: dict[str, er.RegistryEntry] = { + entry.unique_id: entry + for entry in er.async_entries_for_device( + entity_registry, old_device_entry.id, include_disabled_entities=True + ) + } + + # Remove this config entry and create a new one under the `airvisual_pro` + # domain: new_entry_data = {**entry.data} new_entry_data.pop(CONF_INTEGRATION_TYPE) - tasks = [ hass.config_entries.async_remove(entry.entry_id), hass.config_entries.flow.async_init( @@ -323,18 +311,52 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: ] await asyncio.gather(*tasks) + # After the migration has occurred, grab the new config and device entries + # (now under the `airvisual_pro` domain): + new_config_entry = next( + entry + for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO) + if entry.data[CONF_IP_ADDRESS] == ip_address + ) + new_device_entry = next( + entry + for entry in dr.async_entries_for_config_entry( + device_registry, new_config_entry.entry_id + ) + ) + + # Update the new device entry with any customizations from the old one: + device_registry.async_update_device( + new_device_entry.id, + area_id=old_device_entry.area_id, + disabled_by=old_device_entry.disabled_by, + name_by_user=old_device_entry.name_by_user, + ) + + # Update the new entity entries with any customizations from the old ones: + for new_entity_entry in er.async_entries_for_device( + entity_registry, new_device_entry.id, include_disabled_entities=True + ): + if old_entity_entry := old_entity_entries.get( + new_entity_entry.unique_id + ): + entity_registry.async_update_entity( + new_entity_entry.entity_id, + area_id=old_entity_entry.area_id, + device_class=old_entity_entry.device_class, + disabled_by=old_entity_entry.disabled_by, + hidden_by=old_entity_entry.hidden_by, + icon=old_entity_entry.icon, + name=old_entity_entry.name, + new_entity_id=old_entity_entry.entity_id, + unit_of_measurement=old_entity_entry.unit_of_measurement, + ) + # If any automations are using the old device ID, create a Repairs issues # with instructions on how to update it: if device_automations := automation.automations_with_device( hass, old_device_entry.id ): - new_config_entry = async_get_pro_config_entry_by_ip_address( - hass, ip_address - ) - new_device_entry = async_get_pro_device_by_config_entry( - hass, new_config_entry - ) - async_create_issue( hass, DOMAIN, diff --git a/tests/components/airvisual/conftest.py b/tests/components/airvisual/conftest.py index 3e83b41a5af..8ef060c3116 100644 --- a/tests/components/airvisual/conftest.py +++ b/tests/components/airvisual/conftest.py @@ -1,6 +1,6 @@ """Define test fixtures for AirVisual.""" import json -from unittest.mock import patch +from unittest.mock import AsyncMock, Mock, patch import pytest @@ -56,17 +56,27 @@ def data_fixture(): return json.loads(load_fixture("data.json", "airvisual")) +@pytest.fixture(name="pro_data", scope="session") +def pro_data_fixture(): + """Define an update coordinator data example for the Pro.""" + return json.loads(load_fixture("data.json", "airvisual_pro")) + + +@pytest.fixture(name="pro") +def pro_fixture(pro_data): + """Define a mocked NodeSamba object.""" + return Mock( + async_connect=AsyncMock(), + async_disconnect=AsyncMock(), + async_get_latest_measurements=AsyncMock(return_value=pro_data), + ) + + @pytest.fixture(name="setup_airvisual") async def setup_airvisual_fixture(hass, config, data): """Define a fixture to set up AirVisual.""" with patch("pyairvisual.air_quality.AirQuality.city"), patch( "pyairvisual.air_quality.AirQuality.nearest_city", return_value=data - ), patch("pyairvisual.node.NodeSamba.async_connect"), patch( - "pyairvisual.node.NodeSamba.async_get_latest_measurements" - ), patch( - "pyairvisual.node.NodeSamba.async_disconnect" - ), patch( - "homeassistant.components.airvisual.PLATFORMS", [] ): assert await async_setup_component(hass, DOMAIN, config) await hass.async_block_till_done() diff --git a/tests/components/airvisual/test_config_flow.py b/tests/components/airvisual/test_config_flow.py index d322726340a..7bad9af1002 100644 --- a/tests/components/airvisual/test_config_flow.py +++ b/tests/components/airvisual/test_config_flow.py @@ -1,5 +1,5 @@ """Define tests for the AirVisual config flow.""" -from unittest.mock import Mock, patch +from unittest.mock import patch from pyairvisual.cloud_api import ( InvalidKeyError, @@ -21,6 +21,7 @@ from homeassistant.components.airvisual import ( INTEGRATION_TYPE_GEOGRAPHY_NAME, INTEGRATION_TYPE_NODE_PRO, ) +from homeassistant.components.airvisual_pro import DOMAIN as AIRVISUAL_PRO_DOMAIN from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER from homeassistant.const import ( CONF_API_KEY, @@ -31,8 +32,7 @@ from homeassistant.const import ( CONF_SHOW_ON_MAP, CONF_STATE, ) -from homeassistant.helpers import issue_registry as ir -from homeassistant.setup import async_setup_component +from homeassistant.helpers import device_registry as dr, issue_registry as ir from tests.common import MockConfigEntry @@ -169,42 +169,41 @@ async def test_migration_1_2(hass, config, config_entry, setup_airvisual, unique } -@pytest.mark.parametrize( - "config,config_entry_version,unique_id", - [ - ( - { - CONF_IP_ADDRESS: "192.168.1.100", - CONF_PASSWORD: "abcde12345", - CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO, - }, - 2, - "192.16.1.100", - ) - ], -) -async def test_migration_2_3(hass, config, config_entry, unique_id): +async def test_migration_2_3(hass, pro): """Test migrating from version 2 to 3.""" + old_pro_entry = MockConfigEntry( + domain=DOMAIN, + unique_id="192.168.1.100", + data={ + CONF_IP_ADDRESS: "192.168.1.100", + CONF_PASSWORD: "abcde12345", + CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO, + }, + version=2, + ) + old_pro_entry.add_to_hass(hass) + + device_registry = dr.async_get(hass) + device_registry.async_get_or_create( + name="192.168.1.100", + config_entry_id=old_pro_entry.entry_id, + identifiers={(DOMAIN, "ABCDE12345")}, + ) + with patch( "homeassistant.components.airvisual.automation.automations_with_device", return_value=["automation.test_automation"], ), patch( - "homeassistant.components.airvisual.async_get_pro_config_entry_by_ip_address", - return_value=MockConfigEntry( - domain="airvisual_pro", - unique_id="192.168.1.100", - data={CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "abcde12345"}, - version=3, - ), + "homeassistant.components.airvisual_pro.NodeSamba", return_value=pro ), patch( - "homeassistant.components.airvisual.async_get_pro_device_by_config_entry", - return_value=Mock(id="abcde12345"), + "homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro ): - assert await async_setup_component(hass, DOMAIN, config) + await hass.config_entries.async_setup(old_pro_entry.entry_id) await hass.async_block_till_done() - airvisual_config_entries = hass.config_entries.async_entries(DOMAIN) - assert len(airvisual_config_entries) == 0 + for domain, entry_count in ((DOMAIN, 0), (AIRVISUAL_PRO_DOMAIN, 1)): + entries = hass.config_entries.async_entries(domain) + assert len(entries) == entry_count issue_registry = ir.async_get(hass) assert len(issue_registry.issues) == 1