diff --git a/homeassistant/components/apple_tv/__init__.py b/homeassistant/components/apple_tv/__init__.py index 29bc5634b38..200984ca883 100644 --- a/homeassistant/components/apple_tv/__init__.py +++ b/homeassistant/components/apple_tv/__init__.py @@ -36,6 +36,7 @@ _LOGGER = logging.getLogger(__name__) DEFAULT_NAME = "Apple TV" +BACKOFF_TIME_LOWER_LIMIT = 15 # seconds BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes SIGNAL_CONNECTED = "apple_tv_connected" @@ -241,7 +242,11 @@ class AppleTVManager: if self.atv is None: self._connection_attempts += 1 backoff = min( - randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT + max( + BACKOFF_TIME_LOWER_LIMIT, + randrange(2 ** self._connection_attempts), + ), + BACKOFF_TIME_UPPER_LIMIT, ) _LOGGER.debug("Reconnecting in %d seconds", backoff) @@ -271,17 +276,12 @@ class AppleTVManager: return atvs[0] _LOGGER.debug( - "Failed to find device %s with address %s, trying to scan", + "Failed to find device %s with address %s", self.config_entry.title, address, ) - - atvs = await scan(self.hass.loop, identifier=identifiers, protocol=protocols) - if atvs: - return atvs[0] - - _LOGGER.debug("Failed to find device %s, trying later", self.config_entry.title) - + # We no longer multicast scan for the device since as soon as async_step_zeroconf runs, + # it will update the address and reload the config entry when the device is found. return None async def _connect(self, conf): diff --git a/homeassistant/components/apple_tv/config_flow.py b/homeassistant/components/apple_tv/config_flow.py index 16a757b2ebb..878483e0ce7 100644 --- a/homeassistant/components/apple_tv/config_flow.py +++ b/homeassistant/components/apple_tv/config_flow.py @@ -1,4 +1,5 @@ """Config flow for Apple TV integration.""" +import asyncio from collections import deque from ipaddress import ip_address import logging @@ -27,6 +28,8 @@ INPUT_PIN_SCHEMA = vol.Schema({vol.Required(CONF_PIN, default=None): int}) DEFAULT_START_OFF = False +DISCOVERY_AGGREGATION_TIME = 15 # seconds + async def device_scan(identifier, loop): """Scan for a specific device using identifier as filter.""" @@ -46,12 +49,13 @@ async def device_scan(identifier, loop): except ValueError: return None - for hosts in (_host_filter(), None): - scan_result = await scan(loop, timeout=3, hosts=hosts) - matches = [atv for atv in scan_result if _filter_device(atv)] + # If we have an address, only probe that address to avoid + # broadcast traffic on the network + scan_result = await scan(loop, timeout=3, hosts=_host_filter()) + matches = [atv for atv in scan_result if _filter_device(atv)] - if matches: - return matches[0], matches[0].all_identifiers + if matches: + return matches[0], matches[0].all_identifiers return None, None @@ -93,10 +97,12 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): existing config entry. If that's the case, the unique_id from that entry is re-used, otherwise the newly discovered identifier is used instead. """ + all_identifiers = set(self.atv.all_identifiers) for entry in self._async_current_entries(): - for identifier in self.atv.all_identifiers: - if identifier in entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]): - return entry.unique_id + if all_identifiers.intersection( + entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) + ): + return entry.unique_id return self.atv.identifier async def async_step_reauth(self, user_input=None): @@ -149,22 +155,18 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): self, discovery_info: zeroconf.ZeroconfServiceInfo ) -> data_entry_flow.FlowResult: """Handle device found via zeroconf.""" + host = discovery_info.host + self._async_abort_entries_match({CONF_ADDRESS: host}) service_type = discovery_info.type[:-1] # Remove leading . name = discovery_info.name.replace(f".{service_type}.", "") properties = discovery_info.properties # Extract unique identifier from service - self.scan_filter = get_unique_id(service_type, name, properties) - if self.scan_filter is None: + unique_id = get_unique_id(service_type, name, properties) + if unique_id is None: return self.async_abort(reason="unknown") - # Scan for the device in order to extract _all_ unique identifiers assigned to - # it. Not doing it like this will yield multiple config flows for the same - # device, one per protocol, which is undesired. - return await self.async_find_device_wrapper(self.async_found_zeroconf_device) - - async def async_found_zeroconf_device(self, user_input=None): - """Handle device found after Zeroconf discovery.""" + # # Suppose we have a device with three services: A, B and C. Let's assume # service A is discovered by Zeroconf, triggering a device scan that also finds # service B but *not* C. An identifier is picked from one of the services and @@ -177,31 +179,63 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): # since both flows really represent the same device. They will however end up # as two separate flows. # - # To solve this, all identifiers found during a device scan is stored as + # To solve this, all identifiers are stored as # "all_identifiers" in the flow context. When a new service is discovered, the # code below will check these identifiers for all active flows and abort if a # match is found. Before aborting, the original flow is updated with any # potentially new identifiers. In the example above, when service C is # discovered, the identifier of service C will be inserted into # "all_identifiers" of the original flow (making the device complete). - for flow in self._async_in_progress(): - for identifier in self.atv.all_identifiers: - if identifier not in flow["context"].get("all_identifiers", []): - continue + # + # Wait DISCOVERY_AGGREGATION_TIME for multiple services to be + # discovered via zeroconf. Once the first service is discovered + # this allows other services to be discovered inside the time + # window before triggering a scan of the device. This prevents + # multiple scans of the device at the same time since each + # apple_tv device has multiple services that are discovered by + # zeroconf. + # + await asyncio.sleep(DISCOVERY_AGGREGATION_TIME) + self._async_check_in_progress_and_set_address(host, unique_id) + + # Scan for the device in order to extract _all_ unique identifiers assigned to + # it. Not doing it like this will yield multiple config flows for the same + # device, one per protocol, which is undesired. + self.scan_filter = host + return await self.async_find_device_wrapper(self.async_found_zeroconf_device) + + @callback + def _async_check_in_progress_and_set_address(self, host: str, unique_id: str): + """Check for in-progress flows and update them with identifiers if needed. + + This code must not await between checking in progress and setting the host + or it will have a race condition where no flows move forward. + """ + for flow in self._async_in_progress(include_uninitialized=True): + context = flow["context"] + if ( + context.get("source") != config_entries.SOURCE_ZEROCONF + or context.get(CONF_ADDRESS) != host + ): + continue + if ( + "all_identifiers" in context + and unique_id not in context["all_identifiers"] + ): # Add potentially new identifiers from this device to the existing flow - identifiers = set(flow["context"]["all_identifiers"]) - identifiers.update(self.atv.all_identifiers) - flow["context"]["all_identifiers"] = list(identifiers) - - raise data_entry_flow.AbortFlow("already_in_progress") + context["all_identifiers"].append(unique_id) + raise data_entry_flow.AbortFlow("already_in_progress") + self.context[CONF_ADDRESS] = host + async def async_found_zeroconf_device(self, user_input=None): + """Handle device found after Zeroconf discovery.""" self.context["all_identifiers"] = self.atv.all_identifiers - # Also abort if an integration with this identifier already exists await self.async_set_unique_id(self.device_identifier) - self._abort_if_unique_id_configured() - + # but be sure to update the address if its changed so the scanner + # will probe the new address + self._abort_if_unique_id_configured(updates={CONF_ADDRESS: self.atv.address}) self.context["identifier"] = self.unique_id return await self.async_step_confirm() @@ -245,14 +279,22 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): else model_str(dev_info.model) ), } - - if not allow_exist: - for identifier in self.atv.all_identifiers: - for entry in self._async_current_entries(): - if identifier in entry.data.get( - CONF_IDENTIFIERS, [entry.unique_id] - ): - raise DeviceAlreadyConfigured() + all_identifiers = set(self.atv.all_identifiers) + for entry in self._async_current_entries(): + if not all_identifiers.intersection( + entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) + ): + continue + if entry.data.get(CONF_ADDRESS) != self.atv.address: + self.hass.config_entries.async_update_entry( + entry, + data={**entry.data, CONF_ADDRESS: self.atv.address}, + ) + self.hass.async_create_task( + self.hass.config_entries.async_reload(entry.entry_id) + ) + if not allow_exist: + raise DeviceAlreadyConfigured() async def async_step_confirm(self, user_input=None): """Handle user-confirmation of discovered node.""" diff --git a/tests/components/apple_tv/common.py b/tests/components/apple_tv/common.py index 687ad256f3a..ddb8c1348d9 100644 --- a/tests/components/apple_tv/common.py +++ b/tests/components/apple_tv/common.py @@ -49,10 +49,10 @@ def create_conf(name, address, *services): return atv -def mrp_service(enabled=True): +def mrp_service(enabled=True, unique_id="mrpid"): """Create example MRP service.""" return conf.ManualService( - "mrpid", + unique_id, Protocol.MRP, 5555, {}, @@ -70,3 +70,14 @@ def airplay_service(): {}, pairing_requirement=const.PairingRequirement.Mandatory, ) + + +def raop_service(): + """Create example RAOP service.""" + return conf.ManualService( + "AABBCCDDEEFF", + Protocol.RAOP, + 7000, + {}, + pairing_requirement=const.PairingRequirement.Mandatory, + ) diff --git a/tests/components/apple_tv/conftest.py b/tests/components/apple_tv/conftest.py index c8a9725610c..504e2d22e1d 100644 --- a/tests/components/apple_tv/conftest.py +++ b/tests/components/apple_tv/conftest.py @@ -96,12 +96,19 @@ def full_device(mock_scan, dmap_pin): @pytest.fixture def mrp_device(mock_scan): """Mock pyatv.scan.""" - mock_scan.result.append( - create_conf( - "127.0.0.1", - "MRP Device", - mrp_service(), - ) + mock_scan.result.extend( + [ + create_conf( + "127.0.0.1", + "MRP Device", + mrp_service(), + ), + create_conf( + "127.0.0.2", + "MRP Device 2", + mrp_service(unique_id="unrelated"), + ), + ] ) yield mock_scan diff --git a/tests/components/apple_tv/test_config_flow.py b/tests/components/apple_tv/test_config_flow.py index 39403b837ca..1d2a5a459d4 100644 --- a/tests/components/apple_tv/test_config_flow.py +++ b/tests/components/apple_tv/test_config_flow.py @@ -1,6 +1,6 @@ """Test config flow.""" -from unittest.mock import patch +from unittest.mock import ANY, patch from pyatv import exceptions from pyatv.const import PairingRequirement, Protocol @@ -8,14 +8,15 @@ import pytest from homeassistant import config_entries, data_entry_flow from homeassistant.components import zeroconf +from homeassistant.components.apple_tv import CONF_ADDRESS, config_flow from homeassistant.components.apple_tv.const import CONF_START_OFF, DOMAIN -from .common import airplay_service, create_conf, mrp_service +from .common import airplay_service, create_conf, mrp_service, raop_service from tests.common import MockConfigEntry DMAP_SERVICE = zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_touch-able._tcp.local.", @@ -24,6 +25,23 @@ DMAP_SERVICE = zeroconf.ZeroconfServiceInfo( ) +RAOP_SERVICE = zeroconf.ZeroconfServiceInfo( + host="127.0.0.1", + hostname="mock_hostname", + port=None, + type="_raop._tcp.local.", + name="AABBCCDDEEFF@Master Bed._raop._tcp.local.", + properties={"am": "AppleTV11,1"}, +) + + +@pytest.fixture(autouse=True) +def zero_aggregation_time(): + """Prevent the aggregation time from delaying the tests.""" + with patch.object(config_flow, "DISCOVERY_AGGREGATION_TIME", 0): + yield + + @pytest.fixture(autouse=True) def use_mocked_zeroconf(mock_async_zeroconf): """Mock zeroconf in all tests.""" @@ -507,7 +525,7 @@ async def test_zeroconf_unsupported_service_aborts(hass): DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", name="mock_name", port=None, @@ -521,11 +539,25 @@ async def test_zeroconf_unsupported_service_aborts(hass): async def test_zeroconf_add_mrp_device(hass, mrp_device, pairing): """Test add MRP device discovered by zeroconf.""" + unrelated_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=zeroconf.ZeroconfServiceInfo( + host="127.0.0.2", + hostname="mock_hostname", + port=None, + name="Kitchen", + properties={"UniqueIdentifier": "unrelated", "Name": "Kitchen"}, + type="_mediaremotetv._tcp.local.", + ), + ) + assert unrelated_result["type"] == data_entry_flow.RESULT_TYPE_FORM + result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, name="Kitchen", @@ -586,6 +618,37 @@ async def test_zeroconf_add_dmap_device(hass, dmap_device, dmap_pin, pairing): } +async def test_zeroconf_ip_change(hass, mock_scan): + """Test that the config entry gets updated when the ip changes and reloads.""" + entry = MockConfigEntry( + domain="apple_tv", unique_id="mrpid", data={CONF_ADDRESS: "127.0.0.2"} + ) + unrelated_entry = MockConfigEntry( + domain="apple_tv", unique_id="unrelated", data={CONF_ADDRESS: "127.0.0.2"} + ) + unrelated_entry.add_to_hass(hass) + entry.add_to_hass(hass) + mock_scan.result = [ + create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) + ] + + with patch( + "homeassistant.components.apple_tv.async_setup_entry", return_value=True + ) as mock_async_setup: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=DMAP_SERVICE, + ) + await hass.async_block_till_done() + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "already_configured" + assert len(mock_async_setup.mock_calls) == 2 + assert entry.data[CONF_ADDRESS] == "127.0.0.1" + assert unrelated_entry.data[CONF_ADDRESS] == "127.0.0.2" + + async def test_zeroconf_add_existing_aborts(hass, dmap_device): """Test start new zeroconf flow while existing flow is active aborts.""" await hass.config_entries.flow.async_init( @@ -638,7 +701,7 @@ async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan): DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_airplay._tcp.local.", @@ -658,7 +721,7 @@ async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan): DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_mediaremotetv._tcp.local.", @@ -681,7 +744,7 @@ async def test_zeroconf_missing_device_during_protocol_resolve( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_airplay._tcp.local.", @@ -700,7 +763,7 @@ async def test_zeroconf_missing_device_during_protocol_resolve( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_mediaremotetv._tcp.local.", @@ -733,7 +796,7 @@ async def test_zeroconf_additional_protocol_resolve_failure( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_airplay._tcp.local.", @@ -752,7 +815,7 @@ async def test_zeroconf_additional_protocol_resolve_failure( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_mediaremotetv._tcp.local.", @@ -785,7 +848,7 @@ async def test_zeroconf_pair_additionally_found_protocols( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_airplay._tcp.local.", @@ -793,9 +856,26 @@ async def test_zeroconf_pair_additionally_found_protocols( properties={"deviceid": "airplayid"}, ), ) + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + await hass.async_block_till_done() mock_scan.result = [ - create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) + create_conf("127.0.0.1", "Device", raop_service(), airplay_service()) + ] + + # Find the same device again, but now also with RAOP service. The first flow should + # be updated with the RAOP service. + await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=RAOP_SERVICE, + ) + await hass.async_block_till_done() + + mock_scan.result = [ + create_conf( + "127.0.0.1", "Device", raop_service(), mrp_service(), airplay_service() + ) ] # Find the same device again, but now also with MRP service. The first flow should @@ -804,7 +884,7 @@ async def test_zeroconf_pair_additionally_found_protocols( DOMAIN, context={"source": config_entries.SOURCE_ZEROCONF}, data=zeroconf.ZeroconfServiceInfo( - host="mock_host", + host="127.0.0.1", hostname="mock_hostname", port=None, type="_mediaremotetv._tcp.local.", @@ -812,29 +892,41 @@ async def test_zeroconf_pair_additionally_found_protocols( properties={"UniqueIdentifier": "mrpid", "Name": "Kitchen"}, ), ) + await hass.async_block_till_done() - # Verify that _both_ protocols are paired + # Verify that all protocols are paired result2 = await hass.config_entries.flow.async_configure( result["flow_id"], {}, ) - assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM - assert result2["step_id"] == "pair_with_pin" - assert result2["description_placeholders"] == {"protocol": "MRP"} + assert result2["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result2["step_id"] == "pair_no_pin" + assert result2["description_placeholders"] == {"pin": ANY, "protocol": "RAOP"} + + # Verify that all protocols are paired result3 = await hass.config_entries.flow.async_configure( result["flow_id"], - {"pin": 1234}, + {}, ) + assert result3["type"] == data_entry_flow.RESULT_TYPE_FORM assert result3["step_id"] == "pair_with_pin" - assert result3["description_placeholders"] == {"protocol": "AirPlay"} + assert result3["description_placeholders"] == {"protocol": "MRP"} result4 = await hass.config_entries.flow.async_configure( result["flow_id"], {"pin": 1234}, ) - assert result4["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY + assert result4["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result4["step_id"] == "pair_with_pin" + assert result4["description_placeholders"] == {"protocol": "AirPlay"} + + result5 = await hass.config_entries.flow.async_configure( + result["flow_id"], + {"pin": 1234}, + ) + assert result5["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY # Re-configuration