Add migration to migrate 'homewizard_energy' to 'homewizard' (#65594)

This commit is contained in:
Duco Sebel 2022-02-04 18:12:35 +01:00 committed by GitHub
parent 8245ff7473
commit a97e69196c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 193 additions and 4 deletions

View File

@ -3,10 +3,11 @@ import logging
from aiohwenergy import DisabledError
from homeassistant.config_entries import ConfigEntry
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.const import CONF_IP_ADDRESS
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.update_coordinator import UpdateFailed
from .const import DOMAIN, PLATFORMS
@ -20,6 +21,51 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
_LOGGER.debug("__init__ async_setup_entry")
# Migrate `homewizard_energy` (custom_component) to `homewizard`
if entry.source == SOURCE_IMPORT and "old_config_entry_id" in entry.data:
# Remove the old config entry ID from the entry data so we don't try this again
# on the next setup
data = entry.data.copy()
old_config_entry_id = data.pop("old_config_entry_id")
hass.config_entries.async_update_entry(entry, data=data)
_LOGGER.debug(
(
"Setting up imported homewizard_energy entry %s for the first time as "
"homewizard entry %s"
),
old_config_entry_id,
entry.entry_id,
)
ent_reg = er.async_get(hass)
for entity in er.async_entries_for_config_entry(ent_reg, old_config_entry_id):
_LOGGER.debug("Removing %s", entity.entity_id)
ent_reg.async_remove(entity.entity_id)
_LOGGER.debug("Re-creating %s for the new config entry", entity.entity_id)
# We will precreate the entity so that any customizations can be preserved
new_entity = ent_reg.async_get_or_create(
entity.domain,
DOMAIN,
entity.unique_id,
suggested_object_id=entity.entity_id.split(".")[1],
disabled_by=entity.disabled_by,
config_entry=entry,
original_name=entity.original_name,
original_icon=entity.original_icon,
)
_LOGGER.debug("Re-created %s", new_entity.entity_id)
# If there are customizations on the old entity, apply them to the new one
if entity.name or entity.icon:
ent_reg.async_update_entity(
new_entity.entity_id, name=entity.name, icon=entity.icon
)
# Remove the old config entry and now the entry is fully migrated
hass.async_create_task(hass.config_entries.async_remove(old_config_entry_id))
# Create coordinator
coordinator = Coordinator(hass, entry.data[CONF_IP_ADDRESS])
try:

View File

@ -28,6 +28,21 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Initialize the HomeWizard config flow."""
self.config: dict[str, str | int] = {}
async def async_step_import(self, import_config: dict) -> FlowResult:
"""Handle a flow initiated by older `homewizard_energy` component."""
_LOGGER.debug("config_flow async_step_import")
self.hass.components.persistent_notification.async_create(
(
"The custom integration of HomeWizard Energy has been migrated to core. "
"You can safely remove the custom integration from the custom_integrations folder."
),
"HomeWizard Energy",
f"homewizard_energy_to_{DOMAIN}",
)
return await self.async_step_user({CONF_IP_ADDRESS: import_config["host"]})
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
@ -59,12 +74,17 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
}
)
data: dict[str, str] = {CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS]}
if self.source == config_entries.SOURCE_IMPORT:
old_config_entry_id = self.context["old_config_entry_id"]
assert self.hass.config_entries.async_get_entry(old_config_entry_id)
data["old_config_entry_id"] = old_config_entry_id
# Add entry
return self.async_create_entry(
title=f"{device_info[CONF_PRODUCT_NAME]} ({device_info[CONF_SERIAL]})",
data={
CONF_IP_ADDRESS: user_input[CONF_IP_ADDRESS],
},
data=data,
)
async def async_step_zeroconf(

View File

@ -12,6 +12,8 @@ from homeassistant.data_entry_flow import RESULT_TYPE_ABORT, RESULT_TYPE_CREATE_
from .generator import get_mock_device
from tests.common import MockConfigEntry
_LOGGER = logging.getLogger(__name__)
@ -88,6 +90,37 @@ async def test_discovery_flow_works(hass, aioclient_mock):
assert result["result"].unique_id == "HWE-P1_aabbccddeeff"
async def test_config_flow_imports_entry(aioclient_mock, hass):
"""Test config flow accepts imported configuration."""
device = get_mock_device()
mock_entry = MockConfigEntry(domain="homewizard_energy", data={"host": "1.2.3.4"})
mock_entry.add_to_hass(hass)
with patch("aiohwenergy.HomeWizardEnergy", return_value=device,), patch(
"homeassistant.components.homewizard.async_setup_entry",
return_value=True,
) as mock_setup_entry:
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": config_entries.SOURCE_IMPORT,
"old_config_entry_id": mock_entry.entry_id,
},
data=mock_entry.data,
)
assert result["type"] == "create_entry"
assert result["title"] == f"{device.device.product_name} (aabbccddeeff)"
assert result["data"][CONF_IP_ADDRESS] == "1.2.3.4"
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert len(device.initialize.mock_calls) == 1
assert len(device.close.mock_calls) == 1
assert len(mock_setup_entry.mock_calls) == 1
async def test_discovery_disabled_api(hass, aioclient_mock):
"""Test discovery detecting disabled api."""

View File

@ -4,9 +4,11 @@ from unittest.mock import patch
from aiohwenergy import AiohwenergyException, DisabledError
from homeassistant import config_entries
from homeassistant.components.homewizard.const import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import CONF_IP_ADDRESS
from homeassistant.helpers import entity_registry as er
from .generator import get_mock_device
@ -68,6 +70,94 @@ async def test_load_failed_host_unavailable(aioclient_mock, hass):
assert entry.state is ConfigEntryState.SETUP_RETRY
async def test_init_accepts_and_migrates_old_entry(aioclient_mock, hass):
"""Test config flow accepts imported configuration."""
device = get_mock_device()
# Add original entry
original_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_IP_ADDRESS: "1.2.3.4"},
entry_id="old_id",
)
original_entry.add_to_hass(hass)
# Give it some entities to see of they migrate properly
ent_reg = er.async_get(hass)
old_entity_active_power = ent_reg.async_get_or_create(
"sensor",
"homewizard_energy",
"p1_active_power_unique_id",
config_entry=original_entry,
original_name="Active Power",
suggested_object_id="p1_active_power",
)
old_entity_switch = ent_reg.async_get_or_create(
"switch",
"homewizard_energy",
"socket_switch_unique_id",
config_entry=original_entry,
original_name="Switch",
suggested_object_id="socket_switch",
)
old_entity_disabled_sensor = ent_reg.async_get_or_create(
"sensor",
"homewizard_energy",
"socket_disabled_unique_id",
config_entry=original_entry,
original_name="Switch Disabled",
suggested_object_id="socket_disabled",
disabled_by=er.DISABLED_USER,
)
# Update some user-customs
ent_reg.async_update_entity(old_entity_active_power.entity_id, name="new_name")
ent_reg.async_update_entity(old_entity_switch.entity_id, icon="new_icon")
imported_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_IP_ADDRESS: "1.2.3.4", "old_config_entry_id": "old_id"},
source=config_entries.SOURCE_IMPORT,
entry_id="new_id",
)
imported_entry.add_to_hass(hass)
# Add the entry_id to trigger migration
with patch(
"aiohwenergy.HomeWizardEnergy",
return_value=device,
):
await hass.config_entries.async_setup(imported_entry.entry_id)
await hass.async_block_till_done()
assert original_entry.state is ConfigEntryState.NOT_LOADED
assert imported_entry.state is ConfigEntryState.LOADED
# Check if new entities are migrated
new_entity_active_power = ent_reg.async_get(old_entity_active_power.entity_id)
assert new_entity_active_power.platform == DOMAIN
assert new_entity_active_power.name == "new_name"
assert new_entity_active_power.icon is None
assert new_entity_active_power.original_name == "Active Power"
assert new_entity_active_power.unique_id == "p1_active_power_unique_id"
assert new_entity_active_power.disabled_by is None
new_entity_switch = ent_reg.async_get(old_entity_switch.entity_id)
assert new_entity_switch.platform == DOMAIN
assert new_entity_switch.name is None
assert new_entity_switch.icon == "new_icon"
assert new_entity_switch.original_name == "Switch"
assert new_entity_switch.unique_id == "socket_switch_unique_id"
assert new_entity_switch.disabled_by is None
new_entity_disabled_sensor = ent_reg.async_get(old_entity_disabled_sensor.entity_id)
assert new_entity_disabled_sensor.platform == DOMAIN
assert new_entity_disabled_sensor.name is None
assert new_entity_disabled_sensor.original_name == "Switch Disabled"
assert new_entity_disabled_sensor.unique_id == "socket_disabled_unique_id"
assert new_entity_disabled_sensor.disabled_by == er.DISABLED_USER
async def test_load_detect_api_disabled(aioclient_mock, hass):
"""Test setup detects disabled API."""