diff --git a/homeassistant/components/apple_tv/config_flow.py b/homeassistant/components/apple_tv/config_flow.py index 545d9c5fd90..4b630fc777b 100644 --- a/homeassistant/components/apple_tv/config_flow.py +++ b/homeassistant/components/apple_tv/config_flow.py @@ -1,4 +1,6 @@ """Config flow for Apple TV integration.""" +from __future__ import annotations + import asyncio from collections import deque from ipaddress import ip_address @@ -98,12 +100,19 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): re-used, otherwise the newly discovered identifier is used instead. """ all_identifiers = set(self.atv.all_identifiers) + if unique_id := self._entry_unique_id_from_identifers(all_identifiers): + return unique_id + return self.atv.identifier + + @callback + def _entry_unique_id_from_identifers(self, all_identifiers: set[str]) -> str | None: + """Search existing entries for an identifier and return the unique id.""" for entry in self._async_current_entries(): if all_identifiers.intersection( entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) ): return entry.unique_id - return self.atv.identifier + return None async def async_step_reauth(self, user_input=None): """Handle initial step when updating invalid credentials.""" @@ -166,6 +175,20 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if unique_id is None: return self.async_abort(reason="unknown") + if existing_unique_id := self._entry_unique_id_from_identifers({unique_id}): + await self.async_set_unique_id(existing_unique_id) + self._abort_if_unique_id_configured(updates={CONF_ADDRESS: host}) + + self._async_abort_entries_match({CONF_ADDRESS: host}) + await self._async_aggregate_discoveries(host, unique_id) + # Scan for the device in order to extract _all_ unique identifiers assigned to + # it. Not doing it like this will yield multiple config flows for the same + # device, one per protocol, which is undesired. + self.scan_filter = host + return await self.async_find_device_wrapper(self.async_found_zeroconf_device) + + async def _async_aggregate_discoveries(self, host: str, unique_id: str) -> None: + """Wait for multiple zeroconf services to be discovered an aggregate them.""" # # Suppose we have a device with three services: A, B and C. Let's assume # service A is discovered by Zeroconf, triggering a device scan that also finds @@ -195,23 +218,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): # apple_tv device has multiple services that are discovered by # zeroconf. # + self._async_check_and_update_in_progress(host, unique_id) await asyncio.sleep(DISCOVERY_AGGREGATION_TIME) - - self._async_check_in_progress_and_set_address(host, unique_id) - - # Scan for the device in order to extract _all_ unique identifiers assigned to - # it. Not doing it like this will yield multiple config flows for the same - # device, one per protocol, which is undesired. - self.scan_filter = host - return await self.async_find_device_wrapper(self.async_found_zeroconf_device) + # Check again after sleeping in case another flow + # has made progress while we yielded to the event loop + self._async_check_and_update_in_progress(host, unique_id) + # Host must only be set AFTER checking and updating in progress + # flows or we will have a race condition where no flows move forward. + self.context[CONF_ADDRESS] = host @callback - def _async_check_in_progress_and_set_address(self, host: str, unique_id: str): - """Check for in-progress flows and update them with identifiers if needed. - - This code must not await between checking in progress and setting the host - or it will have a race condition where no flows move forward. - """ + def _async_check_and_update_in_progress(self, host: str, unique_id: str) -> None: + """Check for in-progress flows and update them with identifiers if needed.""" for flow in self._async_in_progress(include_uninitialized=True): context = flow["context"] if ( @@ -226,7 +244,6 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): # Add potentially new identifiers from this device to the existing flow context["all_identifiers"].append(unique_id) raise data_entry_flow.AbortFlow("already_in_progress") - self.context[CONF_ADDRESS] = host async def async_found_zeroconf_device(self, user_input=None): """Handle device found after Zeroconf discovery.""" diff --git a/tests/components/apple_tv/test_config_flow.py b/tests/components/apple_tv/test_config_flow.py index e9f352041cb..b4811e57739 100644 --- a/tests/components/apple_tv/test_config_flow.py +++ b/tests/components/apple_tv/test_config_flow.py @@ -10,7 +10,11 @@ import pytest from homeassistant import config_entries, data_entry_flow from homeassistant.components import zeroconf from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow -from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN +from homeassistant.components.apple_tv.const import ( + CONF_IDENTIFIERS, + CONF_START_OFF, + DOMAIN, +) from .common import airplay_service, create_conf, mrp_service, raop_service @@ -652,6 +656,45 @@ async def test_zeroconf_ip_change(hass, mock_scan): assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2" +async def test_zeroconf_ip_change_via_secondary_identifier(hass, mock_scan): + """Test that the config entry gets updated when the ip changes and reloads. + + Instead of checking only the unique id, all the identifiers + in the config entry are checked + """ + entry = MockConfigEntry( + domain="apple_tv", + unique_id="aa:bb:cc:dd:ee:ff", + data={CONF_IDENTIFIERS: ["mrpid"], CONF_ADDRESS: "127.0.0.2"}, + ) + unrelated_entry = MockConfigEntry( + domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"} + ) + unrelated_entry.add_to_hass(hass) + entry.add_to_hass(hass) + mock_scan.result = [ + create_conf( + IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service() + ) + ] + + with patch( + "homeassistant.components.apple_tv.async_setup_entry", return_value=True + ) as mock_async_setup: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=DMAP_SERVICE, + ) + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + assert len(mock_async_setup.mock_calls) == 2 + assert entry.data[CONF_ADDRESS] == "127.0.0.1" + assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2" + + async def test_zeroconf_add_existing_aborts(hass, dmap_device): """Test start new zeroconf flow while existing flow is active aborts.""" await hass.config_entries.flow.async_init(