Ensure AirVisual Pro migration includes device and entity customizations (#84798)

* Ensure AirVisual Pro migration includes device and entity customizations

* Update homeassistant/components/airvisual/__init__.py

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>

* Code review

* Fix tests

* Fix tests FOR REAL

Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
Aaron Bach 2022-12-30 14:47:41 -07:00 committed by GitHub
parent 60de2a82c7
commit 34dc47ad10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 77 deletions

View File

@ -32,6 +32,7 @@ from homeassistant.helpers import (
aiohttp_client,
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.entity import EntityDescription
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
@ -116,36 +117,6 @@ def async_get_geography_id(geography_dict: Mapping[str, Any]) -> str:
)
@callback
def async_get_pro_config_entry_by_ip_address(
hass: HomeAssistant, ip_address: str
) -> ConfigEntry:
"""Get the Pro config entry related to an IP address."""
[config_entry] = [
entry
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
if entry.data[CONF_IP_ADDRESS] == ip_address
]
return config_entry
@callback
def async_get_pro_device_by_config_entry(
hass: HomeAssistant, config_entry: ConfigEntry
) -> dr.DeviceEntry:
"""Get the Pro device entry related to a config entry.
Note that a Pro config entry can only contain a single device.
"""
device_registry = dr.async_get(hass)
[device_entry] = [
device_entry
for device_entry in device_registry.devices.values()
if config_entry.entry_id in device_entry.config_entries
]
return device_entry
@callback
def async_sync_geo_coordinator_update_intervals(
hass: HomeAssistant, api_key: str
@ -305,14 +276,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
version = 3
if entry.data[CONF_INTEGRATION_TYPE] == INTEGRATION_TYPE_NODE_PRO:
device_registry = dr.async_get(hass)
entity_registry = er.async_get(hass)
ip_address = entry.data[CONF_IP_ADDRESS]
# Get the existing Pro device entry before it is removed by the migration:
old_device_entry = async_get_pro_device_by_config_entry(hass, entry)
# Store the existing Pro device before the migration removes it:
old_device_entry = next(
entry
for entry in dr.async_entries_for_config_entry(
device_registry, entry.entry_id
)
)
# Store the existing Pro entity entries (mapped by unique ID) before the
# migration removes it:
old_entity_entries: dict[str, er.RegistryEntry] = {
entry.unique_id: entry
for entry in er.async_entries_for_device(
entity_registry, old_device_entry.id, include_disabled_entities=True
)
}
# Remove this config entry and create a new one under the `airvisual_pro`
# domain:
new_entry_data = {**entry.data}
new_entry_data.pop(CONF_INTEGRATION_TYPE)
tasks = [
hass.config_entries.async_remove(entry.entry_id),
hass.config_entries.flow.async_init(
@ -323,18 +311,52 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
]
await asyncio.gather(*tasks)
# After the migration has occurred, grab the new config and device entries
# (now under the `airvisual_pro` domain):
new_config_entry = next(
entry
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
if entry.data[CONF_IP_ADDRESS] == ip_address
)
new_device_entry = next(
entry
for entry in dr.async_entries_for_config_entry(
device_registry, new_config_entry.entry_id
)
)
# Update the new device entry with any customizations from the old one:
device_registry.async_update_device(
new_device_entry.id,
area_id=old_device_entry.area_id,
disabled_by=old_device_entry.disabled_by,
name_by_user=old_device_entry.name_by_user,
)
# Update the new entity entries with any customizations from the old ones:
for new_entity_entry in er.async_entries_for_device(
entity_registry, new_device_entry.id, include_disabled_entities=True
):
if old_entity_entry := old_entity_entries.get(
new_entity_entry.unique_id
):
entity_registry.async_update_entity(
new_entity_entry.entity_id,
area_id=old_entity_entry.area_id,
device_class=old_entity_entry.device_class,
disabled_by=old_entity_entry.disabled_by,
hidden_by=old_entity_entry.hidden_by,
icon=old_entity_entry.icon,
name=old_entity_entry.name,
new_entity_id=old_entity_entry.entity_id,
unit_of_measurement=old_entity_entry.unit_of_measurement,
)
# If any automations are using the old device ID, create a Repairs issues
# with instructions on how to update it:
if device_automations := automation.automations_with_device(
hass, old_device_entry.id
):
new_config_entry = async_get_pro_config_entry_by_ip_address(
hass, ip_address
)
new_device_entry = async_get_pro_device_by_config_entry(
hass, new_config_entry
)
async_create_issue(
hass,
DOMAIN,

View File

@ -1,6 +1,6 @@
"""Define test fixtures for AirVisual."""
import json
from unittest.mock import patch
from unittest.mock import AsyncMock, Mock, patch
import pytest
@ -56,17 +56,27 @@ def data_fixture():
return json.loads(load_fixture("data.json", "airvisual"))
@pytest.fixture(name="pro_data", scope="session")
def pro_data_fixture():
"""Define an update coordinator data example for the Pro."""
return json.loads(load_fixture("data.json", "airvisual_pro"))
@pytest.fixture(name="pro")
def pro_fixture(pro_data):
"""Define a mocked NodeSamba object."""
return Mock(
async_connect=AsyncMock(),
async_disconnect=AsyncMock(),
async_get_latest_measurements=AsyncMock(return_value=pro_data),
)
@pytest.fixture(name="setup_airvisual")
async def setup_airvisual_fixture(hass, config, data):
"""Define a fixture to set up AirVisual."""
with patch("pyairvisual.air_quality.AirQuality.city"), patch(
"pyairvisual.air_quality.AirQuality.nearest_city", return_value=data
), patch("pyairvisual.node.NodeSamba.async_connect"), patch(
"pyairvisual.node.NodeSamba.async_get_latest_measurements"
), patch(
"pyairvisual.node.NodeSamba.async_disconnect"
), patch(
"homeassistant.components.airvisual.PLATFORMS", []
):
assert await async_setup_component(hass, DOMAIN, config)
await hass.async_block_till_done()

View File

@ -1,5 +1,5 @@
"""Define tests for the AirVisual config flow."""
from unittest.mock import Mock, patch
from unittest.mock import patch
from pyairvisual.cloud_api import (
InvalidKeyError,
@ -21,6 +21,7 @@ from homeassistant.components.airvisual import (
INTEGRATION_TYPE_GEOGRAPHY_NAME,
INTEGRATION_TYPE_NODE_PRO,
)
from homeassistant.components.airvisual_pro import DOMAIN as AIRVISUAL_PRO_DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER
from homeassistant.const import (
CONF_API_KEY,
@ -31,8 +32,7 @@ from homeassistant.const import (
CONF_SHOW_ON_MAP,
CONF_STATE,
)
from homeassistant.helpers import issue_registry as ir
from homeassistant.setup import async_setup_component
from homeassistant.helpers import device_registry as dr, issue_registry as ir
from tests.common import MockConfigEntry
@ -169,42 +169,41 @@ async def test_migration_1_2(hass, config, config_entry, setup_airvisual, unique
}
@pytest.mark.parametrize(
"config,config_entry_version,unique_id",
[
(
{
CONF_IP_ADDRESS: "192.168.1.100",
CONF_PASSWORD: "abcde12345",
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
},
2,
"192.16.1.100",
)
],
)
async def test_migration_2_3(hass, config, config_entry, unique_id):
async def test_migration_2_3(hass, pro):
"""Test migrating from version 2 to 3."""
old_pro_entry = MockConfigEntry(
domain=DOMAIN,
unique_id="192.168.1.100",
data={
CONF_IP_ADDRESS: "192.168.1.100",
CONF_PASSWORD: "abcde12345",
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
},
version=2,
)
old_pro_entry.add_to_hass(hass)
device_registry = dr.async_get(hass)
device_registry.async_get_or_create(
name="192.168.1.100",
config_entry_id=old_pro_entry.entry_id,
identifiers={(DOMAIN, "ABCDE12345")},
)
with patch(
"homeassistant.components.airvisual.automation.automations_with_device",
return_value=["automation.test_automation"],
), patch(
"homeassistant.components.airvisual.async_get_pro_config_entry_by_ip_address",
return_value=MockConfigEntry(
domain="airvisual_pro",
unique_id="192.168.1.100",
data={CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "abcde12345"},
version=3,
),
"homeassistant.components.airvisual_pro.NodeSamba", return_value=pro
), patch(
"homeassistant.components.airvisual.async_get_pro_device_by_config_entry",
return_value=Mock(id="abcde12345"),
"homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro
):
assert await async_setup_component(hass, DOMAIN, config)
await hass.config_entries.async_setup(old_pro_entry.entry_id)
await hass.async_block_till_done()
airvisual_config_entries = hass.config_entries.async_entries(DOMAIN)
assert len(airvisual_config_entries) == 0
for domain, entry_count in ((DOMAIN, 0), (AIRVISUAL_PRO_DOMAIN, 1)):
entries = hass.config_entries.async_entries(domain)
assert len(entries) == entry_count
issue_registry = ir.async_get(hass)
assert len(issue_registry.issues) == 1