Fix updating apple_tv addresses (#61724)

This commit is contained in:
J. Nick Koston 2021-12-13 21:38:22 +01:00 committed by GitHub
parent 7adffe6927
commit 65ec251309
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 42 additions and 14 deletions

View File

@ -235,7 +235,9 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
await self.async_set_unique_id(self.device_identifier) await self.async_set_unique_id(self.device_identifier)
# but be sure to update the address if its changed so the scanner # but be sure to update the address if its changed so the scanner
# will probe the new address # will probe the new address
self._abort_if_unique_id_configured(updates={CONF_ADDRESS: self.atv.address}) self._abort_if_unique_id_configured(
updates={CONF_ADDRESS: str(self.atv.address)}
)
self.context["identifier"] = self.unique_id self.context["identifier"] = self.unique_id
return await self.async_step_confirm() return await self.async_step_confirm()
@ -280,15 +282,16 @@ class AppleTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
), ),
} }
all_identifiers = set(self.atv.all_identifiers) all_identifiers = set(self.atv.all_identifiers)
discovered_ip_address = str(self.atv.address)
for entry in self._async_current_entries(): for entry in self._async_current_entries():
if not all_identifiers.intersection( if not all_identifiers.intersection(
entry.data.get(CONF_IDENTIFIERS, [entry.unique_id]) entry.data.get(CONF_IDENTIFIERS, [entry.unique_id])
): ):
continue continue
if entry.data.get(CONF_ADDRESS) != self.atv.address: if entry.data.get(CONF_ADDRESS) != discovered_ip_address:
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
entry, entry,
data={**entry.data, CONF_ADDRESS: self.atv.address}, data={**entry.data, CONF_ADDRESS: discovered_ip_address},
) )
self.hass.async_create_task( self.hass.async_create_task(
self.hass.config_entries.async_reload(entry.entry_id) self.hass.config_entries.async_reload(entry.entry_id)

View File

@ -1,5 +1,6 @@
"""Test config flow.""" """Test config flow."""
from ipaddress import IPv4Address
from unittest.mock import ANY, patch from unittest.mock import ANY, patch
from pyatv import exceptions from pyatv import exceptions
@ -629,7 +630,9 @@ async def test_zeroconf_ip_change(hass, mock_scan):
unrelated_entry.add_to_hass(hass) unrelated_entry.add_to_hass(hass)
entry.add_to_hass(hass) entry.add_to_hass(hass)
mock_scan.result = [ mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
] ]
with patch( with patch(
@ -695,7 +698,9 @@ async def test_zeroconf_unexpected_error(hass, mock_scan):
async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan): async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan):
"""Test discovering unsupported zeroconf service.""" """Test discovering unsupported zeroconf service."""
mock_scan.result = [create_conf("127.0.0.1", "Device", airplay_service())] mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
@ -714,7 +719,9 @@ async def test_zeroconf_abort_if_other_in_progress(hass, mock_scan):
assert result["step_id"] == "confirm" assert result["step_id"] == "confirm"
mock_scan.result = [ mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
] ]
result2 = await hass.config_entries.flow.async_init( result2 = await hass.config_entries.flow.async_init(
@ -737,7 +744,9 @@ async def test_zeroconf_missing_device_during_protocol_resolve(
hass, mock_scan, pairing, mock_zeroconf hass, mock_scan, pairing, mock_zeroconf
): ):
"""Test discovery after service been added to existing flow with missing device.""" """Test discovery after service been added to existing flow with missing device."""
mock_scan.result = [create_conf("127.0.0.1", "Device", airplay_service())] mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
# Find device with AirPlay service and set up flow for it # Find device with AirPlay service and set up flow for it
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -754,7 +763,9 @@ async def test_zeroconf_missing_device_during_protocol_resolve(
) )
mock_scan.result = [ mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
] ]
# Find the same device again, but now also with MRP service. The first flow should # Find the same device again, but now also with MRP service. The first flow should
@ -789,7 +800,9 @@ async def test_zeroconf_additional_protocol_resolve_failure(
hass, mock_scan, pairing, mock_zeroconf hass, mock_scan, pairing, mock_zeroconf
): ):
"""Test discovery with missing service.""" """Test discovery with missing service."""
mock_scan.result = [create_conf("127.0.0.1", "Device", airplay_service())] mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
# Find device with AirPlay service and set up flow for it # Find device with AirPlay service and set up flow for it
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -806,7 +819,9 @@ async def test_zeroconf_additional_protocol_resolve_failure(
) )
mock_scan.result = [ mock_scan.result = [
create_conf("127.0.0.1", "Device", mrp_service(), airplay_service()) create_conf(
IPv4Address("127.0.0.1"), "Device", mrp_service(), airplay_service()
)
] ]
# Find the same device again, but now also with MRP service. The first flow should # Find the same device again, but now also with MRP service. The first flow should
@ -824,7 +839,9 @@ async def test_zeroconf_additional_protocol_resolve_failure(
), ),
) )
mock_scan.result = [create_conf("127.0.0.1", "Device", airplay_service())] mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
# Number of services found during initial scan (1) will not match the updated count # Number of services found during initial scan (1) will not match the updated count
# (2), so it will trigger a re-scan to find all services. This will however fail # (2), so it will trigger a re-scan to find all services. This will however fail
@ -841,7 +858,9 @@ async def test_zeroconf_pair_additionally_found_protocols(
hass, mock_scan, pairing, mock_zeroconf hass, mock_scan, pairing, mock_zeroconf
): ):
"""Test discovered protocols are merged to original flow.""" """Test discovered protocols are merged to original flow."""
mock_scan.result = [create_conf("127.0.0.1", "Device", airplay_service())] mock_scan.result = [
create_conf(IPv4Address("127.0.0.1"), "Device", airplay_service())
]
# Find device with AirPlay service and set up flow for it # Find device with AirPlay service and set up flow for it
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -860,7 +879,9 @@ async def test_zeroconf_pair_additionally_found_protocols(
await hass.async_block_till_done() await hass.async_block_till_done()
mock_scan.result = [ mock_scan.result = [
create_conf("127.0.0.1", "Device", raop_service(), airplay_service()) create_conf(
IPv4Address("127.0.0.1"), "Device", raop_service(), airplay_service()
)
] ]
# Find the same device again, but now also with RAOP service. The first flow should # Find the same device again, but now also with RAOP service. The first flow should
@ -874,7 +895,11 @@ async def test_zeroconf_pair_additionally_found_protocols(
mock_scan.result = [ mock_scan.result = [
create_conf( create_conf(
"127.0.0.1", "Device", raop_service(), mrp_service(), airplay_service() IPv4Address("127.0.0.1"),
"Device",
raop_service(),
mrp_service(),
airplay_service(),
) )
] ]