Dismiss discoveries when the underlying device disappears (#88340)

* Implement discovery removals

Bluetooth, HomeKit, SSDP, and Zeroconf now implement
dismissing discoveries when the underlying discovered
device disappears

* cover

* add zeroconf test

* cover

* cover bluetooth

* fix rediscover
This commit is contained in:
J. Nick Koston 2023-02-17 14:51:19 -06:00 committed by GitHub
parent 710b250c1d
commit 331102e592
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 432 additions and 11 deletions

View File

@ -315,6 +315,8 @@ class BluetoothManager:
# the device from all the interval tracking since it is no longer # the device from all the interval tracking since it is no longer
# available for both connectable and non-connectable # available for both connectable and non-connectable
tracker.async_remove_address(address) tracker.async_remove_address(address)
self._integration_matcher.async_clear_address(address)
self._async_dismiss_discoveries(address)
service_info = history.pop(address) service_info = history.pop(address)
@ -327,6 +329,14 @@ class BluetoothManager:
except Exception: # pylint: disable=broad-except except Exception: # pylint: disable=broad-except
_LOGGER.exception("Error in unavailable callback") _LOGGER.exception("Error in unavailable callback")
def _async_dismiss_discoveries(self, address: str) -> None:
"""Dismiss all discoveries for the given address."""
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
BluetoothServiceInfoBleak,
lambda service_info: bool(service_info.address == address),
):
self.hass.config_entries.flow.async_abort(flow["flow_id"])
def _prefer_previous_adv_from_different_source( def _prefer_previous_adv_from_different_source(
self, self,
old: BluetoothServiceInfoBleak, old: BluetoothServiceInfoBleak,

View File

@ -518,7 +518,11 @@ class Scanner:
CaseInsensitiveDict(combined_headers.as_dict(), **info_desc) CaseInsensitiveDict(combined_headers.as_dict(), **info_desc)
) )
if not callbacks and not matching_domains: if (
not callbacks
and not matching_domains
and source != SsdpSource.ADVERTISEMENT_BYEBYE
):
return return
discovery_info = discovery_info_from_headers_and_description( discovery_info = discovery_info_from_headers_and_description(
@ -534,6 +538,7 @@ class Scanner:
# Config flows should only be created for alive/update messages from alive devices # Config flows should only be created for alive/update messages from alive devices
if source == SsdpSource.ADVERTISEMENT_BYEBYE: if source == SsdpSource.ADVERTISEMENT_BYEBYE:
self._async_dismiss_discoveries(discovery_info)
return return
_LOGGER.debug("Discovery info: %s", discovery_info) _LOGGER.debug("Discovery info: %s", discovery_info)
@ -548,6 +553,19 @@ class Scanner:
discovery_info, discovery_info,
) )
def _async_dismiss_discoveries(
self, byebye_discovery_info: SsdpServiceInfo
) -> None:
"""Dismiss all discoveries for the given address."""
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
SsdpServiceInfo,
lambda service_info: bool(
service_info.ssdp_st == byebye_discovery_info.ssdp_st
and service_info.ssdp_location == byebye_discovery_info.ssdp_location
),
):
self.hass.config_entries.flow.async_abort(flow["flow_id"])
async def _async_get_description_dict( async def _async_get_description_dict(
self, location: str | None self, location: str | None
) -> Mapping[str, str]: ) -> Mapping[str, str]:

View File

@ -378,6 +378,14 @@ class ZeroconfDiscovery:
if self.async_service_browser: if self.async_service_browser:
await self.async_service_browser.async_cancel() await self.async_service_browser.async_cancel()
def _async_dismiss_discoveries(self, name: str) -> None:
"""Dismiss all discoveries for the given name."""
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
ZeroconfServiceInfo,
lambda service_info: bool(service_info.name == name),
):
self.hass.config_entries.flow.async_abort(flow["flow_id"])
@callback @callback
def async_service_update( def async_service_update(
self, self,
@ -395,6 +403,7 @@ class ZeroconfDiscovery:
) )
if state_change == ServiceStateChange.Removed: if state_change == ServiceStateChange.Removed:
self._async_dismiss_discoveries(name)
return return
try: try:

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
import abc import abc
from collections.abc import Iterable, Mapping from collections.abc import Callable, Iterable, Mapping
import copy import copy
from dataclasses import dataclass from dataclasses import dataclass
import logging import logging
@ -138,6 +138,7 @@ class FlowManager(abc.ABC):
self.hass = hass self.hass = hass
self._progress: dict[str, FlowHandler] = {} self._progress: dict[str, FlowHandler] = {}
self._handler_progress_index: dict[str, set[str]] = {} self._handler_progress_index: dict[str, set[str]] = {}
self._init_data_process_index: dict[type, set[str]] = {}
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
@ -198,6 +199,23 @@ class FlowManager(abc.ABC):
self._async_progress_by_handler(handler), include_uninitialized self._async_progress_by_handler(handler), include_uninitialized
) )
@callback
def async_progress_by_init_data_type(
self,
init_data_type: type,
matcher: Callable[[Any], bool],
include_uninitialized: bool = False,
) -> list[FlowResult]:
"""Return flows in progress init matching by data type as a partial FlowResult."""
return _async_flow_handler_to_flow_result(
(
self._progress[flow_id]
for flow_id in self._init_data_process_index.get(init_data_type, {})
if matcher(self._progress[flow_id].init_data)
),
include_uninitialized,
)
@callback @callback
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]: def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
"""Return the flows in progress by handler.""" """Return the flows in progress by handler."""
@ -301,19 +319,33 @@ class FlowManager(abc.ABC):
@callback @callback
def _async_add_flow_progress(self, flow: FlowHandler) -> None: def _async_add_flow_progress(self, flow: FlowHandler) -> None:
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
self._init_data_process_index.setdefault(init_data_type, set()).add(
flow.flow_id
)
self._progress[flow.flow_id] = flow self._progress[flow.flow_id] = flow
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id) self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
@callback
def _async_remove_flow_from_index(self, flow: FlowHandler) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
self._init_data_process_index[init_data_type].remove(flow.flow_id)
if not self._init_data_process_index[init_data_type]:
del self._init_data_process_index[init_data_type]
handler = flow.handler
self._handler_progress_index[handler].remove(flow.flow_id)
if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler]
@callback @callback
def _async_remove_flow_progress(self, flow_id: str) -> None: def _async_remove_flow_progress(self, flow_id: str) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if (flow := self._progress.pop(flow_id, None)) is None: if (flow := self._progress.pop(flow_id, None)) is None:
raise UnknownFlow raise UnknownFlow
handler = flow.handler self._async_remove_flow_from_index(flow)
self._handler_progress_index[handler].remove(flow.flow_id)
if not self._handler_progress_index[handler]:
del self._handler_progress_index[handler]
try: try:
flow.async_remove() flow.async_remove()
except Exception as err: # pylint: disable=broad-except except Exception as err: # pylint: disable=broad-except

View File

@ -803,3 +803,138 @@ async def test_goes_unavailable_connectable_only_and_recovers(
unsetup_connectable_scanner_2() unsetup_connectable_scanner_2()
cancel_not_connectable_scanner() cancel_not_connectable_scanner()
unsetup_not_connectable_scanner() unsetup_not_connectable_scanner()
async def test_goes_unavailable_dismisses_discovery(
hass: HomeAssistant, mock_bluetooth_adapters: None
) -> None:
"""Test that unavailable will dismiss any active discoveries."""
assert await async_setup_component(hass, bluetooth.DOMAIN, {})
await hass.async_block_till_done()
assert async_scanner_count(hass, connectable=False) == 0
switchbot_device_non_connectable = BLEDevice(
"44:44:33:11:23:45",
"wohand",
{},
rssi=-100,
)
switchbot_device_adv = generate_advertisement_data(
local_name="wohand",
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
manufacturer_data={1: b"\x01"},
rssi=-100,
)
callbacks = []
def _fake_subscriber(
service_info: BluetoothServiceInfo,
change: BluetoothChange,
) -> None:
"""Fake subscriber for the BleakScanner."""
callbacks.append((service_info, change))
cancel = bluetooth.async_register_callback(
hass,
_fake_subscriber,
{"address": "44:44:33:11:23:45", "connectable": False},
BluetoothScanningMode.ACTIVE,
)
class FakeScanner(BaseHaRemoteScanner):
def inject_advertisement(
self, device: BLEDevice, advertisement_data: AdvertisementData
) -> None:
"""Inject an advertisement."""
self._async_on_advertisement(
device.address,
advertisement_data.rssi,
device.name,
advertisement_data.service_uuids,
advertisement_data.service_data,
advertisement_data.manufacturer_data,
advertisement_data.tx_power,
{"scanner_specific_data": "test"},
)
def clear_all_devices(self) -> None:
"""Clear all devices."""
self._discovered_device_advertisement_datas.clear()
self._discovered_device_timestamps.clear()
new_info_callback = async_get_advertisement_callback(hass)
connector = (
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
)
non_connectable_scanner = FakeScanner(
hass,
"connectable",
"connectable",
new_info_callback,
connector,
False,
)
unsetup_connectable_scanner = non_connectable_scanner.async_setup()
cancel_connectable_scanner = _get_manager().async_register_scanner(
non_connectable_scanner, True
)
non_connectable_scanner.inject_advertisement(
switchbot_device_non_connectable, switchbot_device_adv
)
assert async_ble_device_from_address(hass, "44:44:33:11:23:45", False) is not None
assert async_scanner_count(hass, connectable=True) == 1
assert len(callbacks) == 1
assert (
"44:44:33:11:23:45"
in non_connectable_scanner.discovered_devices_and_advertisement_data
)
unavailable_callbacks: list[BluetoothServiceInfoBleak] = []
@callback
def _unavailable_callback(service_info: BluetoothServiceInfoBleak) -> None:
"""Wrong device unavailable callback."""
nonlocal unavailable_callbacks
unavailable_callbacks.append(service_info.address)
cancel_unavailable = async_track_unavailable(
hass,
_unavailable_callback,
switchbot_device_non_connectable.address,
connectable=False,
)
assert async_scanner_count(hass, connectable=False) == 1
non_connectable_scanner.clear_all_devices()
assert (
"44:44:33:11:23:45"
not in non_connectable_scanner.discovered_devices_and_advertisement_data
)
monotonic_now = time.monotonic()
with patch.object(
hass.config_entries.flow,
"async_progress_by_init_data_type",
return_value=[{"flow_id": "mock_flow_id"}],
) as mock_async_progress_by_init_data_type, patch.object(
hass.config_entries.flow, "async_abort"
) as mock_async_abort, patch(
"homeassistant.components.bluetooth.manager.MONOTONIC_TIME",
return_value=monotonic_now + FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS,
):
async_fire_time_changed(
hass, dt_util.utcnow() + timedelta(seconds=UNAVAILABLE_TRACK_SECONDS)
)
await hass.async_block_till_done()
assert "44:44:33:11:23:45" in unavailable_callbacks
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"
cancel_unavailable()
cancel()
unsetup_connectable_scanner()
cancel_connectable_scanner()

View File

@ -784,3 +784,76 @@ async def test_ipv4_does_additional_search_for_sonos(
), ),
) )
assert ssdp_listener.async_search.call_args[1] == {} assert ssdp_listener.async_search.call_args[1] == {}
@pytest.mark.usefixtures("mock_get_source_ip")
@patch(
"homeassistant.components.ssdp.async_get_ssdp",
return_value={"mock-domain": [{"deviceType": "Paulus"}]},
)
async def test_flow_dismiss_on_byebye(
mock_get_ssdp,
hass: HomeAssistant,
aioclient_mock: AiohttpClientMocker,
mock_flow_init,
) -> None:
"""Test config flow is only started for alive devices."""
aioclient_mock.get(
"http://1.1.1.1",
text="""
<root>
<device>
<deviceType>Paulus</deviceType>
</device>
</root>
""",
)
ssdp_listener = await init_ssdp_component(hass)
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
# Search should start a flow
mock_ssdp_search_response = _ssdp_headers(
{
"st": "mock-st",
"location": "http://1.1.1.1",
"usn": "uuid:mock-udn::mock-st",
}
)
ssdp_listener._on_search(mock_ssdp_search_response)
await hass.async_block_till_done()
mock_flow_init.assert_awaited_once_with(
"mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY
)
# ssdp:alive advertisement should start a flow
mock_flow_init.reset_mock()
mock_ssdp_advertisement = _ssdp_headers(
{
"location": "http://1.1.1.1",
"usn": "uuid:mock-udn::mock-st",
"nt": "upnp:rootdevice",
"nts": "ssdp:alive",
}
)
ssdp_listener._on_alive(mock_ssdp_advertisement)
await hass.async_block_till_done()
mock_flow_init.assert_awaited_once_with(
"mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY
)
mock_ssdp_advertisement["nts"] = "ssdp:byebye"
# ssdp:byebye advertisement should dismiss existing flows
with patch.object(
hass.config_entries.flow,
"async_progress_by_init_data_type",
return_value=[{"flow_id": "mock_flow_id"}],
) as mock_async_progress_by_init_data_type, patch.object(
hass.config_entries.flow, "async_abort"
) as mock_async_abort:
ssdp_listener._on_byebye(mock_ssdp_advertisement)
await hass.async_block_till_done()
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"

View File

@ -1339,3 +1339,47 @@ async def test_start_with_frontend(
await hass.async_block_till_done() await hass.async_block_till_done()
mock_async_zeroconf.async_register_service.assert_called_once() mock_async_zeroconf.async_register_service.assert_called_once()
async def test_zeroconf_removed(hass: HomeAssistant, mock_async_zeroconf: None) -> None:
"""Test we dismiss flows when a PTR record is removed."""
def _device_removed_mock(ipv6, zeroconf, services, handlers):
"""Call service update handler."""
handlers[0](
zeroconf,
"_http._tcp.local.",
"Shelly108._http._tcp.local.",
ServiceStateChange.Removed,
)
with patch.dict(
zc_gen.ZEROCONF,
{
"_http._tcp.local.": [
{
"domain": "shelly",
"name": "shelly*",
}
]
},
clear=True,
), patch.object(
hass.config_entries.flow,
"async_progress_by_init_data_type",
return_value=[{"flow_id": "mock_flow_id"}],
) as mock_async_progress_by_init_data_type, patch.object(
hass.config_entries.flow, "async_abort"
) as mock_async_abort, patch.object(
zeroconf, "HaAsyncServiceBrowser", side_effect=_device_removed_mock
) as mock_service_browser, patch(
"homeassistant.components.zeroconf.AsyncServiceInfo",
side_effect=get_zeroconf_info_mock("FFAADDCC11DD"),
):
assert await async_setup_component(hass, zeroconf.DOMAIN, {zeroconf.DOMAIN: {}})
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
await hass.async_block_till_done()
assert len(mock_service_browser.mock_calls) == 1
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"

View File

@ -1,4 +1,5 @@
"""Test the flow classes.""" """Test the flow classes."""
import dataclasses
import logging import logging
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
@ -73,7 +74,7 @@ async def test_configure_reuses_handler_instance(manager):
assert len(manager.mock_created_entries) == 0 assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager): async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
"""Test that we reuse instances.""" """Test that we reuse instances."""
@manager.mock_reg_handler("test") @manager.mock_reg_handler("test")
@ -82,7 +83,6 @@ async def test_configure_two_steps(manager):
async def async_step_first(self, user_input=None): async def async_step_first(self, user_input=None):
if user_input is not None: if user_input is not None:
self.init_data = user_input
return await self.async_step_second() return await self.async_step_second()
return self.async_show_form(step_id="first", data_schema=vol.Schema([str])) return self.async_show_form(step_id="first", data_schema=vol.Schema([str]))
@ -93,12 +93,13 @@ async def test_configure_two_steps(manager):
) )
return self.async_show_form(step_id="second", data_schema=vol.Schema([str])) return self.async_show_form(step_id="second", data_schema=vol.Schema([str]))
form = await manager.async_init("test", context={"init_step": "first"}) form = await manager.async_init(
"test", context={"init_step": "first"}, data=["INIT-DATA"]
)
with pytest.raises(vol.Invalid): with pytest.raises(vol.Invalid):
form = await manager.async_configure(form["flow_id"], "INCORRECT-DATA") form = await manager.async_configure(form["flow_id"], "INCORRECT-DATA")
form = await manager.async_configure(form["flow_id"], ["INIT-DATA"])
form = await manager.async_configure(form["flow_id"], ["SECOND-DATA"]) form = await manager.async_configure(form["flow_id"], ["SECOND-DATA"])
assert form["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert form["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert len(manager.async_progress()) == 0 assert len(manager.async_progress()) == 0
@ -553,3 +554,102 @@ async def test_show_menu(hass, manager, menu_options):
) )
assert result["type"] == data_entry_flow.FlowResultType.FORM assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "target1" assert result["step_id"] == "target1"
async def test_find_flows_by_init_data_type(
manager: data_entry_flow.FlowManager,
) -> None:
"""Test we can find flows by init data type."""
@dataclasses.dataclass
class BluetoothDiscoveryData:
"""Bluetooth Discovery data."""
address: str
@dataclasses.dataclass
class WiFiDiscoveryData:
"""WiFi Discovery data."""
address: str
@manager.mock_reg_handler("test")
class TestFlow(data_entry_flow.FlowHandler):
VERSION = 1
async def async_step_first(self, user_input=None):
if user_input is not None:
return await self.async_step_second()
return self.async_show_form(step_id="first", data_schema=vol.Schema([str]))
async def async_step_second(self, user_input=None):
if user_input is not None:
return self.async_create_entry(
title="Test Entry",
data={"init": self.init_data, "user": user_input},
)
return self.async_show_form(step_id="second", data_schema=vol.Schema([str]))
bluetooth_data = BluetoothDiscoveryData("aa:bb:cc:dd:ee:ff")
wifi_data = WiFiDiscoveryData("host")
bluetooth_form = await manager.async_init(
"test", context={"init_step": "first"}, data=bluetooth_data
)
await manager.async_init("test", context={"init_step": "first"}, data=wifi_data)
assert (
len(
manager.async_progress_by_init_data_type(
BluetoothDiscoveryData, lambda data: True
)
)
) == 1
assert (
len(
manager.async_progress_by_init_data_type(
BluetoothDiscoveryData,
lambda data: bool(data.address == "aa:bb:cc:dd:ee:ff"),
)
)
) == 1
assert (
len(
manager.async_progress_by_init_data_type(
BluetoothDiscoveryData, lambda data: bool(data.address == "not it")
)
)
) == 0
wifi_flows = manager.async_progress_by_init_data_type(
WiFiDiscoveryData, lambda data: True
)
assert len(wifi_flows) == 1
bluetooth_result = await manager.async_configure(
bluetooth_form["flow_id"], ["SECOND-DATA"]
)
assert bluetooth_result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
assert len(manager.async_progress()) == 1
assert len(manager.mock_created_entries) == 1
result = manager.mock_created_entries[0]
assert result["handler"] == "test"
assert result["data"] == {"init": bluetooth_data, "user": ["SECOND-DATA"]}
bluetooth_flows = manager.async_progress_by_init_data_type(
BluetoothDiscoveryData, lambda data: True
)
assert len(bluetooth_flows) == 0
wifi_flows = manager.async_progress_by_init_data_type(
WiFiDiscoveryData, lambda data: True
)
assert len(wifi_flows) == 1
manager.async_abort(wifi_flows[0]["flow_id"])
wifi_flows = manager.async_progress_by_init_data_type(
WiFiDiscoveryData, lambda data: True
)
assert len(wifi_flows) == 0
assert len(manager.async_progress()) == 0