mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Ensure AirVisual Pro migration includes device and entity customizations (#84798)
* Ensure AirVisual Pro migration includes device and entity customizations * Update homeassistant/components/airvisual/__init__.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Code review * Fix tests * Fix tests FOR REAL Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
60de2a82c7
commit
34dc47ad10
@ -32,6 +32,7 @@ from homeassistant.helpers import (
|
||||
aiohttp_client,
|
||||
config_validation as cv,
|
||||
device_registry as dr,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.helpers.entity import EntityDescription
|
||||
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
|
||||
@ -116,36 +117,6 @@ def async_get_geography_id(geography_dict: Mapping[str, Any]) -> str:
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_pro_config_entry_by_ip_address(
|
||||
hass: HomeAssistant, ip_address: str
|
||||
) -> ConfigEntry:
|
||||
"""Get the Pro config entry related to an IP address."""
|
||||
[config_entry] = [
|
||||
entry
|
||||
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
|
||||
if entry.data[CONF_IP_ADDRESS] == ip_address
|
||||
]
|
||||
return config_entry
|
||||
|
||||
|
||||
@callback
|
||||
def async_get_pro_device_by_config_entry(
|
||||
hass: HomeAssistant, config_entry: ConfigEntry
|
||||
) -> dr.DeviceEntry:
|
||||
"""Get the Pro device entry related to a config entry.
|
||||
|
||||
Note that a Pro config entry can only contain a single device.
|
||||
"""
|
||||
device_registry = dr.async_get(hass)
|
||||
[device_entry] = [
|
||||
device_entry
|
||||
for device_entry in device_registry.devices.values()
|
||||
if config_entry.entry_id in device_entry.config_entries
|
||||
]
|
||||
return device_entry
|
||||
|
||||
|
||||
@callback
|
||||
def async_sync_geo_coordinator_update_intervals(
|
||||
hass: HomeAssistant, api_key: str
|
||||
@ -305,14 +276,31 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
version = 3
|
||||
|
||||
if entry.data[CONF_INTEGRATION_TYPE] == INTEGRATION_TYPE_NODE_PRO:
|
||||
device_registry = dr.async_get(hass)
|
||||
entity_registry = er.async_get(hass)
|
||||
ip_address = entry.data[CONF_IP_ADDRESS]
|
||||
|
||||
# Get the existing Pro device entry before it is removed by the migration:
|
||||
old_device_entry = async_get_pro_device_by_config_entry(hass, entry)
|
||||
# Store the existing Pro device before the migration removes it:
|
||||
old_device_entry = next(
|
||||
entry
|
||||
for entry in dr.async_entries_for_config_entry(
|
||||
device_registry, entry.entry_id
|
||||
)
|
||||
)
|
||||
|
||||
# Store the existing Pro entity entries (mapped by unique ID) before the
|
||||
# migration removes it:
|
||||
old_entity_entries: dict[str, er.RegistryEntry] = {
|
||||
entry.unique_id: entry
|
||||
for entry in er.async_entries_for_device(
|
||||
entity_registry, old_device_entry.id, include_disabled_entities=True
|
||||
)
|
||||
}
|
||||
|
||||
# Remove this config entry and create a new one under the `airvisual_pro`
|
||||
# domain:
|
||||
new_entry_data = {**entry.data}
|
||||
new_entry_data.pop(CONF_INTEGRATION_TYPE)
|
||||
|
||||
tasks = [
|
||||
hass.config_entries.async_remove(entry.entry_id),
|
||||
hass.config_entries.flow.async_init(
|
||||
@ -323,18 +311,52 @@ async def async_migrate_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
]
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
# After the migration has occurred, grab the new config and device entries
|
||||
# (now under the `airvisual_pro` domain):
|
||||
new_config_entry = next(
|
||||
entry
|
||||
for entry in hass.config_entries.async_entries(DOMAIN_AIRVISUAL_PRO)
|
||||
if entry.data[CONF_IP_ADDRESS] == ip_address
|
||||
)
|
||||
new_device_entry = next(
|
||||
entry
|
||||
for entry in dr.async_entries_for_config_entry(
|
||||
device_registry, new_config_entry.entry_id
|
||||
)
|
||||
)
|
||||
|
||||
# Update the new device entry with any customizations from the old one:
|
||||
device_registry.async_update_device(
|
||||
new_device_entry.id,
|
||||
area_id=old_device_entry.area_id,
|
||||
disabled_by=old_device_entry.disabled_by,
|
||||
name_by_user=old_device_entry.name_by_user,
|
||||
)
|
||||
|
||||
# Update the new entity entries with any customizations from the old ones:
|
||||
for new_entity_entry in er.async_entries_for_device(
|
||||
entity_registry, new_device_entry.id, include_disabled_entities=True
|
||||
):
|
||||
if old_entity_entry := old_entity_entries.get(
|
||||
new_entity_entry.unique_id
|
||||
):
|
||||
entity_registry.async_update_entity(
|
||||
new_entity_entry.entity_id,
|
||||
area_id=old_entity_entry.area_id,
|
||||
device_class=old_entity_entry.device_class,
|
||||
disabled_by=old_entity_entry.disabled_by,
|
||||
hidden_by=old_entity_entry.hidden_by,
|
||||
icon=old_entity_entry.icon,
|
||||
name=old_entity_entry.name,
|
||||
new_entity_id=old_entity_entry.entity_id,
|
||||
unit_of_measurement=old_entity_entry.unit_of_measurement,
|
||||
)
|
||||
|
||||
# If any automations are using the old device ID, create a Repairs issues
|
||||
# with instructions on how to update it:
|
||||
if device_automations := automation.automations_with_device(
|
||||
hass, old_device_entry.id
|
||||
):
|
||||
new_config_entry = async_get_pro_config_entry_by_ip_address(
|
||||
hass, ip_address
|
||||
)
|
||||
new_device_entry = async_get_pro_device_by_config_entry(
|
||||
hass, new_config_entry
|
||||
)
|
||||
|
||||
async_create_issue(
|
||||
hass,
|
||||
DOMAIN,
|
||||
|
@ -1,6 +1,6 @@
|
||||
"""Define test fixtures for AirVisual."""
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
@ -56,17 +56,27 @@ def data_fixture():
|
||||
return json.loads(load_fixture("data.json", "airvisual"))
|
||||
|
||||
|
||||
@pytest.fixture(name="pro_data", scope="session")
|
||||
def pro_data_fixture():
|
||||
"""Define an update coordinator data example for the Pro."""
|
||||
return json.loads(load_fixture("data.json", "airvisual_pro"))
|
||||
|
||||
|
||||
@pytest.fixture(name="pro")
|
||||
def pro_fixture(pro_data):
|
||||
"""Define a mocked NodeSamba object."""
|
||||
return Mock(
|
||||
async_connect=AsyncMock(),
|
||||
async_disconnect=AsyncMock(),
|
||||
async_get_latest_measurements=AsyncMock(return_value=pro_data),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(name="setup_airvisual")
|
||||
async def setup_airvisual_fixture(hass, config, data):
|
||||
"""Define a fixture to set up AirVisual."""
|
||||
with patch("pyairvisual.air_quality.AirQuality.city"), patch(
|
||||
"pyairvisual.air_quality.AirQuality.nearest_city", return_value=data
|
||||
), patch("pyairvisual.node.NodeSamba.async_connect"), patch(
|
||||
"pyairvisual.node.NodeSamba.async_get_latest_measurements"
|
||||
), patch(
|
||||
"pyairvisual.node.NodeSamba.async_disconnect"
|
||||
), patch(
|
||||
"homeassistant.components.airvisual.PLATFORMS", []
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.async_block_till_done()
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""Define tests for the AirVisual config flow."""
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
from pyairvisual.cloud_api import (
|
||||
InvalidKeyError,
|
||||
@ -21,6 +21,7 @@ from homeassistant.components.airvisual import (
|
||||
INTEGRATION_TYPE_GEOGRAPHY_NAME,
|
||||
INTEGRATION_TYPE_NODE_PRO,
|
||||
)
|
||||
from homeassistant.components.airvisual_pro import DOMAIN as AIRVISUAL_PRO_DOMAIN
|
||||
from homeassistant.config_entries import SOURCE_REAUTH, SOURCE_USER
|
||||
from homeassistant.const import (
|
||||
CONF_API_KEY,
|
||||
@ -31,8 +32,7 @@ from homeassistant.const import (
|
||||
CONF_SHOW_ON_MAP,
|
||||
CONF_STATE,
|
||||
)
|
||||
from homeassistant.helpers import issue_registry as ir
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.helpers import device_registry as dr, issue_registry as ir
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
@ -169,42 +169,41 @@ async def test_migration_1_2(hass, config, config_entry, setup_airvisual, unique
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"config,config_entry_version,unique_id",
|
||||
[
|
||||
(
|
||||
{
|
||||
CONF_IP_ADDRESS: "192.168.1.100",
|
||||
CONF_PASSWORD: "abcde12345",
|
||||
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
|
||||
},
|
||||
2,
|
||||
"192.16.1.100",
|
||||
)
|
||||
],
|
||||
)
|
||||
async def test_migration_2_3(hass, config, config_entry, unique_id):
|
||||
async def test_migration_2_3(hass, pro):
|
||||
"""Test migrating from version 2 to 3."""
|
||||
old_pro_entry = MockConfigEntry(
|
||||
domain=DOMAIN,
|
||||
unique_id="192.168.1.100",
|
||||
data={
|
||||
CONF_IP_ADDRESS: "192.168.1.100",
|
||||
CONF_PASSWORD: "abcde12345",
|
||||
CONF_INTEGRATION_TYPE: INTEGRATION_TYPE_NODE_PRO,
|
||||
},
|
||||
version=2,
|
||||
)
|
||||
old_pro_entry.add_to_hass(hass)
|
||||
|
||||
device_registry = dr.async_get(hass)
|
||||
device_registry.async_get_or_create(
|
||||
name="192.168.1.100",
|
||||
config_entry_id=old_pro_entry.entry_id,
|
||||
identifiers={(DOMAIN, "ABCDE12345")},
|
||||
)
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.airvisual.automation.automations_with_device",
|
||||
return_value=["automation.test_automation"],
|
||||
), patch(
|
||||
"homeassistant.components.airvisual.async_get_pro_config_entry_by_ip_address",
|
||||
return_value=MockConfigEntry(
|
||||
domain="airvisual_pro",
|
||||
unique_id="192.168.1.100",
|
||||
data={CONF_IP_ADDRESS: "192.168.1.100", CONF_PASSWORD: "abcde12345"},
|
||||
version=3,
|
||||
),
|
||||
"homeassistant.components.airvisual_pro.NodeSamba", return_value=pro
|
||||
), patch(
|
||||
"homeassistant.components.airvisual.async_get_pro_device_by_config_entry",
|
||||
return_value=Mock(id="abcde12345"),
|
||||
"homeassistant.components.airvisual_pro.config_flow.NodeSamba", return_value=pro
|
||||
):
|
||||
assert await async_setup_component(hass, DOMAIN, config)
|
||||
await hass.config_entries.async_setup(old_pro_entry.entry_id)
|
||||
await hass.async_block_till_done()
|
||||
|
||||
airvisual_config_entries = hass.config_entries.async_entries(DOMAIN)
|
||||
assert len(airvisual_config_entries) == 0
|
||||
for domain, entry_count in ((DOMAIN, 0), (AIRVISUAL_PRO_DOMAIN, 1)):
|
||||
entries = hass.config_entries.async_entries(domain)
|
||||
assert len(entries) == entry_count
|
||||
|
||||
issue_registry = ir.async_get(hass)
|
||||
assert len(issue_registry.issues) == 1
|
||||
|
Loading…
x
Reference in New Issue
Block a user