Ensure AirVisual Pro migration includes device and entity customizations (#84798)

* Ensure AirVisual Pro migration includes device and entity customizations

* Update homeassistant/components/airvisual/__init__.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Code review

* Fix tests

* Fix tests FOR REAL

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Aaron Bach 2022-12-30 14:47:41 -07:00 committed by GitHub
parent 60de2a82c7
commit 34dc47ad10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 77 deletions

View File

@ -32,6 +32,7 @@ from homeassistant.helpers import (
aiohttp_client, aiohttp_client,
config_validation as cv, config_validation as cv,
device_registry as dr, device_registry as dr,
entity_registry as er,
) )
from homeassistant.helpers.entity import EntityDescription from homeassistant.helpers.entity import EntityDescription
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
@ -116,36 +117,6 @@ def async_get_geography_id(geography_dict: Mapping[str, Any]) -> str:
) )
@callback
def async_get_pro_config_entry_by_ip_address(
hass: HomeAssistant, ip_address: str
) -> ConfigEntry:
"""Get the Pro config entry related to an IP address."""
[config_entry] = [
entry
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
if entry.data[CONF_IP_ADDRESS] == ip_address
]
return config_entry
@callback
def async_get_pro_device_by_config_entry(
hass: HomeAssistant, config_entry: ConfigEntry
) -> dr.DeviceEntry:
"""Get the Pro device entry related to a config entry.
Note that a Pro config entry can only contain a single device.
"""
device_registry = dr.async_get(hass)
[device_entry] = [
device_entry
for device_entry in device_registry.devices.values()
if config_entry.entry_id in device_entry.config_entries
]
return device_entry
@callback @callback
def async_sync_geo_coordinator_update_intervals( def async_sync_geo_coordinator_update_intervals(
hass: HomeAssistant, api_key: str hass: HomeAssistant, api_key: str
@ -305,14 +276,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
version = 3 version = 3
if entry.data[CONF_INTEGRATION_TYPE] == INTEGRATION_TYPE_NODE_PRO: if entry.data[CONF_INTEGRATION_TYPE] == INTEGRATION_TYPE_NODE_PRO:
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
ip_address = entry.data[CONF_IP_ADDRESS] ip_address = entry.data[CONF_IP_ADDRESS]
# Get the existing Pro device entry before it is removed by the migration: # Store the existing Pro device before the migration removes it:
old_device_entry = async_get_pro_device_by_config_entry(hass, entry) old_device_entry = next(
entry
for entry in dr.async_entries_for_config_entry(
device_registry, entry.entry_id
)
)
# Store the existing Pro entity entries (mapped by unique ID) before the
# migration removes it:
old_entity_entries: dict[str, er.RegistryEntry] = {
entry.unique_id: entry
for entry in er.async_entries_for_device(
entity_registry, old_device_entry.id, include_disabled_entities=True
)
}
# Remove this config entry and create a new one under the `airvisual_pro`
# domain:
new_entry_data = {**entry.data} new_entry_data = {**entry.data}
new_entry_data.pop(CONF_INTEGRATION_TYPE) new_entry_data.pop(CONF_INTEGRATION_TYPE)
tasks = [ tasks = [
hass.config_entries.async_remove(entry.entry_id), hass.config_entries.async_remove(entry.entry_id),
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
@ -323,18 +311,52 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
] ]
await asyncio.gather(*tasks) await asyncio.gather(*tasks)
# After the migration has occurred, grab the new config and device entries
# (now under the `airvisual_pro` domain):
new_config_entry = next(
entry
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
if entry.data[CONF_IP_ADDRESS] == ip_address
)
new_device_entry = next(
entry
for entry in dr.async_entries_for_config_entry(
device_registry, new_config_entry.entry_id
)
)
# Update the new device entry with any customizations from the old one:
device_registry.async_update_device(
new_device_entry.id,
area_id=old_device_entry.area_id,
disabled_by=old_device_entry.disabled_by,
name_by_user=old_device_entry.name_by_user,
)
# Update the new entity entries with any customizations from the old ones:
for new_entity_entry in er.async_entries_for_device(
entity_registry, new_device_entry.id, include_disabled_entities=True
):
if old_entity_entry := old_entity_entries.get(
new_entity_entry.unique_id
):
entity_registry.async_update_entity(
new_entity_entry.entity_id,
area_id=old_entity_entry.area_id,
device_class=old_entity_entry.device_class,
disabled_by=old_entity_entry.disabled_by,
hidden_by=old_entity_entry.hidden_by,
icon=old_entity_entry.icon,
name=old_entity_entry.name,
new_entity_id=old_entity_entry.entity_id,
unit_of_measurement=old_entity_entry.unit_of_measurement,
)
# If any automations are using the old device ID, create a Repairs issues # If any automations are using the old device ID, create a Repairs issues
# with instructions on how to update it: # with instructions on how to update it:
if device_automations := automation.automations_with_device( if device_automations := automation.automations_with_device(
hass, old_device_entry.id hass, old_device_entry.id
): ):
new_config_entry = async_get_pro_config_entry_by_ip_address(
hass, ip_address
)
new_device_entry = async_get_pro_device_by_config_entry(
hass, new_config_entry
)
async_create_issue( async_create_issue(
hass, hass,
DOMAIN, DOMAIN,

View File

@ -1,6 +1,6 @@
"""Define test fixtures for AirVisual.""" """Define test fixtures for AirVisual."""
import json import json
from unittest.mock import patch from unittest.mock import AsyncMock, Mock, patch
import pytest import pytest
@ -56,17 +56,27 @@ def data_fixture():
return json.loads(load_fixture("data.json", "airvisual")) return json.loads(load_fixture("data.json", "airvisual"))
@pytest.fixture(name="pro_data", scope="session")
def pro_data_fixture():
"""Define an update coordinator data example for the Pro."""
return json.loads(load_fixture("data.json", "airvisual_pro"))
@pytest.fixture(name="pro")
def pro_fixture(pro_data):
"""Define a mocked NodeSamba object."""
return Mock(
async_connect=AsyncMock(),
async_disconnect=AsyncMock(),
async_get_latest_measurements=AsyncMock(return_value=pro_data),
)
@pytest.fixture(name="setup_airvisual") @pytest.fixture(name="setup_airvisual")
async def setup_airvisual_fixture(hass, config, data): async def setup_airvisual_fixture(hass, config, data):
"""Define a fixture to set up AirVisual.""" """Define a fixture to set up AirVisual."""
with patch("pyairvisual.air_quality.AirQuality.city"), patch( with patch("pyairvisual.air_quality.AirQuality.city"), patch(
"pyairvisual.air_quality.AirQuality.nearest_city", return_value=data "pyairvisual.air_quality.AirQuality.nearest_city", return_value=data
), patch("pyairvisual.node.NodeSamba.async_connect"), patch(
"pyairvisual.node.NodeSamba.async_get_latest_measurements"
), patch(
"pyairvisual.node.NodeSamba.async_disconnect"
), patch(
"homeassistant.components.airvisual.PLATFORMS", []
): ):
assert await async_setup_component(hass, DOMAIN, config) assert await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done() await hass.async_block_till_done()

View File

@ -1,5 +1,5 @@
"""Define tests for the AirVisual config flow.""" """Define tests for the AirVisual config flow."""
from unittest.mock import Mock, patch from unittest.mock import patch
from pyairvisual.cloud_api import ( from pyairvisual.cloud_api import (
InvalidKeyError, InvalidKeyError,
@ -21,6 +21,7 @@ from homeassistant.components.airvisual import (
INTEGRATION_TYPE_GEOGRAPHY_NAME, INTEGRATION_TYPE_GEOGRAPHY_NAME,
INTEGRATION_TYPE_NODE_PRO, INTEGRATION_TYPE_NODE_PRO,
) )
from homeassistant.components.airvisual_pro import DOMAIN as AIRVISUAL_PRO_DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER
from homeassistant.const import ( from homeassistant.const import (
CONF_API_KEY, CONF_API_KEY,
@ -31,8 +32,7 @@ from homeassistant.const import (
CONF_SHOW_ON_MAP, CONF_SHOW_ON_MAP,
CONF_STATE, CONF_STATE,
) )
from homeassistant.helpers import issue_registry as ir from homeassistant.helpers import device_registry as dr, issue_registry as ir
from homeassistant.setup import async_setup_component
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -169,42 +169,41 @@ async def test_migration_1_2(hass, config, config_entry, setup_airvisual, unique
} }
@pytest.mark.parametrize( async def test_migration_2_3(hass, pro):
"config,config_entry_version,unique_id",
[
(
{
CONF_IP_ADDRESS: "192.168.1.100",
CONF_PASSWORD: "abcde12345",
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
},
2,
"192.16.1.100",
)
],
)
async def test_migration_2_3(hass, config, config_entry, unique_id):
"""Test migrating from version 2 to 3.""" """Test migrating from version 2 to 3."""
old_pro_entry = MockConfigEntry(
domain=DOMAIN,
unique_id="192.168.1.100",
data={
CONF_IP_ADDRESS: "192.168.1.100",
CONF_PASSWORD: "abcde12345",
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
},
version=2,
)
old_pro_entry.add_to_hass(hass)
device_registry = dr.async_get(hass)
device_registry.async_get_or_create(
name="192.168.1.100",
config_entry_id=old_pro_entry.entry_id,
identifiers={(DOMAIN, "ABCDE12345")},
)
with patch( with patch(
"homeassistant.components.airvisual.automation.automations_with_device", "homeassistant.components.airvisual.automation.automations_with_device",
return_value=["automation.test_automation"], return_value=["automation.test_automation"],
), patch( ), patch(
"homeassistant.components.airvisual.async_get_pro_config_entry_by_ip_address", "homeassistant.components.airvisual_pro.NodeSamba", return_value=pro
return_value=MockConfigEntry(
domain="airvisual_pro",
unique_id="192.168.1.100",
data={CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "abcde12345"},
version=3,
),
), patch( ), patch(
"homeassistant.components.airvisual.async_get_pro_device_by_config_entry", "homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro
return_value=Mock(id="abcde12345"),
): ):
assert await async_setup_component(hass, DOMAIN, config) await hass.config_entries.async_setup(old_pro_entry.entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
airvisual_config_entries = hass.config_entries.async_entries(DOMAIN) for domain, entry_count in ((DOMAIN, 0), (AIRVISUAL_PRO_DOMAIN, 1)):
assert len(airvisual_config_entries) == 0 entries = hass.config_entries.async_entries(domain)
assert len(entries) == entry_count
issue_registry = ir.async_get(hass) issue_registry = ir.async_get(hass)
assert len(issue_registry.issues) == 1 assert len(issue_registry.issues) == 1