mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Dismiss discoveries when the underlying device disappears (#88340)
* Implement discovery removals Bluetooth, HomeKit, SSDP, and Zeroconf now implement dismissing discoveries when the underlying discovered device disappears * cover * add zeroconf test * cover * cover bluetooth * fix rediscover
This commit is contained in:
parent
710b250c1d
commit
331102e592
@ -315,6 +315,8 @@ class BluetoothManager:
|
|||||||
# the device from all the interval tracking since it is no longer
|
# the device from all the interval tracking since it is no longer
|
||||||
# available for both connectable and non-connectable
|
# available for both connectable and non-connectable
|
||||||
tracker.async_remove_address(address)
|
tracker.async_remove_address(address)
|
||||||
|
self._integration_matcher.async_clear_address(address)
|
||||||
|
self._async_dismiss_discoveries(address)
|
||||||
|
|
||||||
service_info = history.pop(address)
|
service_info = history.pop(address)
|
||||||
|
|
||||||
@ -327,6 +329,14 @@ class BluetoothManager:
|
|||||||
except Exception: # pylint: disable=broad-except
|
except Exception: # pylint: disable=broad-except
|
||||||
_LOGGER.exception("Error in unavailable callback")
|
_LOGGER.exception("Error in unavailable callback")
|
||||||
|
|
||||||
|
def _async_dismiss_discoveries(self, address: str) -> None:
|
||||||
|
"""Dismiss all discoveries for the given address."""
|
||||||
|
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
|
||||||
|
BluetoothServiceInfoBleak,
|
||||||
|
lambda service_info: bool(service_info.address == address),
|
||||||
|
):
|
||||||
|
self.hass.config_entries.flow.async_abort(flow["flow_id"])
|
||||||
|
|
||||||
def _prefer_previous_adv_from_different_source(
|
def _prefer_previous_adv_from_different_source(
|
||||||
self,
|
self,
|
||||||
old: BluetoothServiceInfoBleak,
|
old: BluetoothServiceInfoBleak,
|
||||||
|
@ -518,7 +518,11 @@ class Scanner:
|
|||||||
CaseInsensitiveDict(combined_headers.as_dict(), **info_desc)
|
CaseInsensitiveDict(combined_headers.as_dict(), **info_desc)
|
||||||
)
|
)
|
||||||
|
|
||||||
if not callbacks and not matching_domains:
|
if (
|
||||||
|
not callbacks
|
||||||
|
and not matching_domains
|
||||||
|
and source != SsdpSource.ADVERTISEMENT_BYEBYE
|
||||||
|
):
|
||||||
return
|
return
|
||||||
|
|
||||||
discovery_info = discovery_info_from_headers_and_description(
|
discovery_info = discovery_info_from_headers_and_description(
|
||||||
@ -534,6 +538,7 @@ class Scanner:
|
|||||||
|
|
||||||
# Config flows should only be created for alive/update messages from alive devices
|
# Config flows should only be created for alive/update messages from alive devices
|
||||||
if source == SsdpSource.ADVERTISEMENT_BYEBYE:
|
if source == SsdpSource.ADVERTISEMENT_BYEBYE:
|
||||||
|
self._async_dismiss_discoveries(discovery_info)
|
||||||
return
|
return
|
||||||
|
|
||||||
_LOGGER.debug("Discovery info: %s", discovery_info)
|
_LOGGER.debug("Discovery info: %s", discovery_info)
|
||||||
@ -548,6 +553,19 @@ class Scanner:
|
|||||||
discovery_info,
|
discovery_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def _async_dismiss_discoveries(
|
||||||
|
self, byebye_discovery_info: SsdpServiceInfo
|
||||||
|
) -> None:
|
||||||
|
"""Dismiss all discoveries for the given address."""
|
||||||
|
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
|
||||||
|
SsdpServiceInfo,
|
||||||
|
lambda service_info: bool(
|
||||||
|
service_info.ssdp_st == byebye_discovery_info.ssdp_st
|
||||||
|
and service_info.ssdp_location == byebye_discovery_info.ssdp_location
|
||||||
|
),
|
||||||
|
):
|
||||||
|
self.hass.config_entries.flow.async_abort(flow["flow_id"])
|
||||||
|
|
||||||
async def _async_get_description_dict(
|
async def _async_get_description_dict(
|
||||||
self, location: str | None
|
self, location: str | None
|
||||||
) -> Mapping[str, str]:
|
) -> Mapping[str, str]:
|
||||||
|
@ -378,6 +378,14 @@ class ZeroconfDiscovery:
|
|||||||
if self.async_service_browser:
|
if self.async_service_browser:
|
||||||
await self.async_service_browser.async_cancel()
|
await self.async_service_browser.async_cancel()
|
||||||
|
|
||||||
|
def _async_dismiss_discoveries(self, name: str) -> None:
|
||||||
|
"""Dismiss all discoveries for the given name."""
|
||||||
|
for flow in self.hass.config_entries.flow.async_progress_by_init_data_type(
|
||||||
|
ZeroconfServiceInfo,
|
||||||
|
lambda service_info: bool(service_info.name == name),
|
||||||
|
):
|
||||||
|
self.hass.config_entries.flow.async_abort(flow["flow_id"])
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_service_update(
|
def async_service_update(
|
||||||
self,
|
self,
|
||||||
@ -395,6 +403,7 @@ class ZeroconfDiscovery:
|
|||||||
)
|
)
|
||||||
|
|
||||||
if state_change == ServiceStateChange.Removed:
|
if state_change == ServiceStateChange.Removed:
|
||||||
|
self._async_dismiss_discoveries(name)
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import abc
|
import abc
|
||||||
from collections.abc import Iterable, Mapping
|
from collections.abc import Callable, Iterable, Mapping
|
||||||
import copy
|
import copy
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import logging
|
import logging
|
||||||
@ -138,6 +138,7 @@ class FlowManager(abc.ABC):
|
|||||||
self.hass = hass
|
self.hass = hass
|
||||||
self._progress: dict[str, FlowHandler] = {}
|
self._progress: dict[str, FlowHandler] = {}
|
||||||
self._handler_progress_index: dict[str, set[str]] = {}
|
self._handler_progress_index: dict[str, set[str]] = {}
|
||||||
|
self._init_data_process_index: dict[type, set[str]] = {}
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
async def async_create_flow(
|
async def async_create_flow(
|
||||||
@ -198,6 +199,23 @@ class FlowManager(abc.ABC):
|
|||||||
self._async_progress_by_handler(handler), include_uninitialized
|
self._async_progress_by_handler(handler), include_uninitialized
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_progress_by_init_data_type(
|
||||||
|
self,
|
||||||
|
init_data_type: type,
|
||||||
|
matcher: Callable[[Any], bool],
|
||||||
|
include_uninitialized: bool = False,
|
||||||
|
) -> list[FlowResult]:
|
||||||
|
"""Return flows in progress init matching by data type as a partial FlowResult."""
|
||||||
|
return _async_flow_handler_to_flow_result(
|
||||||
|
(
|
||||||
|
self._progress[flow_id]
|
||||||
|
for flow_id in self._init_data_process_index.get(init_data_type, {})
|
||||||
|
if matcher(self._progress[flow_id].init_data)
|
||||||
|
),
|
||||||
|
include_uninitialized,
|
||||||
|
)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
|
def _async_progress_by_handler(self, handler: str) -> list[FlowHandler]:
|
||||||
"""Return the flows in progress by handler."""
|
"""Return the flows in progress by handler."""
|
||||||
@ -301,19 +319,33 @@ class FlowManager(abc.ABC):
|
|||||||
@callback
|
@callback
|
||||||
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
|
def _async_add_flow_progress(self, flow: FlowHandler) -> None:
|
||||||
"""Add a flow to in progress."""
|
"""Add a flow to in progress."""
|
||||||
|
if flow.init_data is not None:
|
||||||
|
init_data_type = type(flow.init_data)
|
||||||
|
self._init_data_process_index.setdefault(init_data_type, set()).add(
|
||||||
|
flow.flow_id
|
||||||
|
)
|
||||||
self._progress[flow.flow_id] = flow
|
self._progress[flow.flow_id] = flow
|
||||||
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
|
self._handler_progress_index.setdefault(flow.handler, set()).add(flow.flow_id)
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _async_remove_flow_from_index(self, flow: FlowHandler) -> None:
|
||||||
|
"""Remove a flow from in progress."""
|
||||||
|
if flow.init_data is not None:
|
||||||
|
init_data_type = type(flow.init_data)
|
||||||
|
self._init_data_process_index[init_data_type].remove(flow.flow_id)
|
||||||
|
if not self._init_data_process_index[init_data_type]:
|
||||||
|
del self._init_data_process_index[init_data_type]
|
||||||
|
handler = flow.handler
|
||||||
|
self._handler_progress_index[handler].remove(flow.flow_id)
|
||||||
|
if not self._handler_progress_index[handler]:
|
||||||
|
del self._handler_progress_index[handler]
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_remove_flow_progress(self, flow_id: str) -> None:
|
def _async_remove_flow_progress(self, flow_id: str) -> None:
|
||||||
"""Remove a flow from in progress."""
|
"""Remove a flow from in progress."""
|
||||||
if (flow := self._progress.pop(flow_id, None)) is None:
|
if (flow := self._progress.pop(flow_id, None)) is None:
|
||||||
raise UnknownFlow
|
raise UnknownFlow
|
||||||
handler = flow.handler
|
self._async_remove_flow_from_index(flow)
|
||||||
self._handler_progress_index[handler].remove(flow.flow_id)
|
|
||||||
if not self._handler_progress_index[handler]:
|
|
||||||
del self._handler_progress_index[handler]
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
flow.async_remove()
|
flow.async_remove()
|
||||||
except Exception as err: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
|
@ -803,3 +803,138 @@ async def test_goes_unavailable_connectable_only_and_recovers(
|
|||||||
unsetup_connectable_scanner_2()
|
unsetup_connectable_scanner_2()
|
||||||
cancel_not_connectable_scanner()
|
cancel_not_connectable_scanner()
|
||||||
unsetup_not_connectable_scanner()
|
unsetup_not_connectable_scanner()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_goes_unavailable_dismisses_discovery(
|
||||||
|
hass: HomeAssistant, mock_bluetooth_adapters: None
|
||||||
|
) -> None:
|
||||||
|
"""Test that unavailable will dismiss any active discoveries."""
|
||||||
|
assert await async_setup_component(hass, bluetooth.DOMAIN, {})
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert async_scanner_count(hass, connectable=False) == 0
|
||||||
|
switchbot_device_non_connectable = BLEDevice(
|
||||||
|
"44:44:33:11:23:45",
|
||||||
|
"wohand",
|
||||||
|
{},
|
||||||
|
rssi=-100,
|
||||||
|
)
|
||||||
|
switchbot_device_adv = generate_advertisement_data(
|
||||||
|
local_name="wohand",
|
||||||
|
service_uuids=["050a021a-0000-1000-8000-00805f9b34fb"],
|
||||||
|
service_data={"050a021a-0000-1000-8000-00805f9b34fb": b"\n\xff"},
|
||||||
|
manufacturer_data={1: b"\x01"},
|
||||||
|
rssi=-100,
|
||||||
|
)
|
||||||
|
callbacks = []
|
||||||
|
|
||||||
|
def _fake_subscriber(
|
||||||
|
service_info: BluetoothServiceInfo,
|
||||||
|
change: BluetoothChange,
|
||||||
|
) -> None:
|
||||||
|
"""Fake subscriber for the BleakScanner."""
|
||||||
|
callbacks.append((service_info, change))
|
||||||
|
|
||||||
|
cancel = bluetooth.async_register_callback(
|
||||||
|
hass,
|
||||||
|
_fake_subscriber,
|
||||||
|
{"address": "44:44:33:11:23:45", "connectable": False},
|
||||||
|
BluetoothScanningMode.ACTIVE,
|
||||||
|
)
|
||||||
|
|
||||||
|
class FakeScanner(BaseHaRemoteScanner):
|
||||||
|
def inject_advertisement(
|
||||||
|
self, device: BLEDevice, advertisement_data: AdvertisementData
|
||||||
|
) -> None:
|
||||||
|
"""Inject an advertisement."""
|
||||||
|
self._async_on_advertisement(
|
||||||
|
device.address,
|
||||||
|
advertisement_data.rssi,
|
||||||
|
device.name,
|
||||||
|
advertisement_data.service_uuids,
|
||||||
|
advertisement_data.service_data,
|
||||||
|
advertisement_data.manufacturer_data,
|
||||||
|
advertisement_data.tx_power,
|
||||||
|
{"scanner_specific_data": "test"},
|
||||||
|
)
|
||||||
|
|
||||||
|
def clear_all_devices(self) -> None:
|
||||||
|
"""Clear all devices."""
|
||||||
|
self._discovered_device_advertisement_datas.clear()
|
||||||
|
self._discovered_device_timestamps.clear()
|
||||||
|
|
||||||
|
new_info_callback = async_get_advertisement_callback(hass)
|
||||||
|
connector = (
|
||||||
|
HaBluetoothConnector(MockBleakClient, "mock_bleak_client", lambda: False),
|
||||||
|
)
|
||||||
|
non_connectable_scanner = FakeScanner(
|
||||||
|
hass,
|
||||||
|
"connectable",
|
||||||
|
"connectable",
|
||||||
|
new_info_callback,
|
||||||
|
connector,
|
||||||
|
False,
|
||||||
|
)
|
||||||
|
unsetup_connectable_scanner = non_connectable_scanner.async_setup()
|
||||||
|
cancel_connectable_scanner = _get_manager().async_register_scanner(
|
||||||
|
non_connectable_scanner, True
|
||||||
|
)
|
||||||
|
non_connectable_scanner.inject_advertisement(
|
||||||
|
switchbot_device_non_connectable, switchbot_device_adv
|
||||||
|
)
|
||||||
|
assert async_ble_device_from_address(hass, "44:44:33:11:23:45", False) is not None
|
||||||
|
assert async_scanner_count(hass, connectable=True) == 1
|
||||||
|
assert len(callbacks) == 1
|
||||||
|
|
||||||
|
assert (
|
||||||
|
"44:44:33:11:23:45"
|
||||||
|
in non_connectable_scanner.discovered_devices_and_advertisement_data
|
||||||
|
)
|
||||||
|
|
||||||
|
unavailable_callbacks: list[BluetoothServiceInfoBleak] = []
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _unavailable_callback(service_info: BluetoothServiceInfoBleak) -> None:
|
||||||
|
"""Wrong device unavailable callback."""
|
||||||
|
nonlocal unavailable_callbacks
|
||||||
|
unavailable_callbacks.append(service_info.address)
|
||||||
|
|
||||||
|
cancel_unavailable = async_track_unavailable(
|
||||||
|
hass,
|
||||||
|
_unavailable_callback,
|
||||||
|
switchbot_device_non_connectable.address,
|
||||||
|
connectable=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert async_scanner_count(hass, connectable=False) == 1
|
||||||
|
|
||||||
|
non_connectable_scanner.clear_all_devices()
|
||||||
|
assert (
|
||||||
|
"44:44:33:11:23:45"
|
||||||
|
not in non_connectable_scanner.discovered_devices_and_advertisement_data
|
||||||
|
)
|
||||||
|
monotonic_now = time.monotonic()
|
||||||
|
with patch.object(
|
||||||
|
hass.config_entries.flow,
|
||||||
|
"async_progress_by_init_data_type",
|
||||||
|
return_value=[{"flow_id": "mock_flow_id"}],
|
||||||
|
) as mock_async_progress_by_init_data_type, patch.object(
|
||||||
|
hass.config_entries.flow, "async_abort"
|
||||||
|
) as mock_async_abort, patch(
|
||||||
|
"homeassistant.components.bluetooth.manager.MONOTONIC_TIME",
|
||||||
|
return_value=monotonic_now + FALLBACK_MAXIMUM_STALE_ADVERTISEMENT_SECONDS,
|
||||||
|
):
|
||||||
|
async_fire_time_changed(
|
||||||
|
hass, dt_util.utcnow() + timedelta(seconds=UNAVAILABLE_TRACK_SECONDS)
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert "44:44:33:11:23:45" in unavailable_callbacks
|
||||||
|
|
||||||
|
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
|
||||||
|
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"
|
||||||
|
|
||||||
|
cancel_unavailable()
|
||||||
|
|
||||||
|
cancel()
|
||||||
|
unsetup_connectable_scanner()
|
||||||
|
cancel_connectable_scanner()
|
||||||
|
@ -784,3 +784,76 @@ async def test_ipv4_does_additional_search_for_sonos(
|
|||||||
),
|
),
|
||||||
)
|
)
|
||||||
assert ssdp_listener.async_search.call_args[1] == {}
|
assert ssdp_listener.async_search.call_args[1] == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.usefixtures("mock_get_source_ip")
|
||||||
|
@patch(
|
||||||
|
"homeassistant.components.ssdp.async_get_ssdp",
|
||||||
|
return_value={"mock-domain": [{"deviceType": "Paulus"}]},
|
||||||
|
)
|
||||||
|
async def test_flow_dismiss_on_byebye(
|
||||||
|
mock_get_ssdp,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
mock_flow_init,
|
||||||
|
) -> None:
|
||||||
|
"""Test config flow is only started for alive devices."""
|
||||||
|
aioclient_mock.get(
|
||||||
|
"http://1.1.1.1",
|
||||||
|
text="""
|
||||||
|
<root>
|
||||||
|
<device>
|
||||||
|
<deviceType>Paulus</deviceType>
|
||||||
|
</device>
|
||||||
|
</root>
|
||||||
|
""",
|
||||||
|
)
|
||||||
|
ssdp_listener = await init_ssdp_component(hass)
|
||||||
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Search should start a flow
|
||||||
|
mock_ssdp_search_response = _ssdp_headers(
|
||||||
|
{
|
||||||
|
"st": "mock-st",
|
||||||
|
"location": "http://1.1.1.1",
|
||||||
|
"usn": "uuid:mock-udn::mock-st",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ssdp_listener._on_search(mock_ssdp_search_response)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
mock_flow_init.assert_awaited_once_with(
|
||||||
|
"mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY
|
||||||
|
)
|
||||||
|
|
||||||
|
# ssdp:alive advertisement should start a flow
|
||||||
|
mock_flow_init.reset_mock()
|
||||||
|
mock_ssdp_advertisement = _ssdp_headers(
|
||||||
|
{
|
||||||
|
"location": "http://1.1.1.1",
|
||||||
|
"usn": "uuid:mock-udn::mock-st",
|
||||||
|
"nt": "upnp:rootdevice",
|
||||||
|
"nts": "ssdp:alive",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
ssdp_listener._on_alive(mock_ssdp_advertisement)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
mock_flow_init.assert_awaited_once_with(
|
||||||
|
"mock-domain", context={"source": config_entries.SOURCE_SSDP}, data=ANY
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_ssdp_advertisement["nts"] = "ssdp:byebye"
|
||||||
|
# ssdp:byebye advertisement should dismiss existing flows
|
||||||
|
with patch.object(
|
||||||
|
hass.config_entries.flow,
|
||||||
|
"async_progress_by_init_data_type",
|
||||||
|
return_value=[{"flow_id": "mock_flow_id"}],
|
||||||
|
) as mock_async_progress_by_init_data_type, patch.object(
|
||||||
|
hass.config_entries.flow, "async_abort"
|
||||||
|
) as mock_async_abort:
|
||||||
|
ssdp_listener._on_byebye(mock_ssdp_advertisement)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
|
||||||
|
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"
|
||||||
|
@ -1339,3 +1339,47 @@ async def test_start_with_frontend(
|
|||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
mock_async_zeroconf.async_register_service.assert_called_once()
|
mock_async_zeroconf.async_register_service.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_zeroconf_removed(hass: HomeAssistant, mock_async_zeroconf: None) -> None:
|
||||||
|
"""Test we dismiss flows when a PTR record is removed."""
|
||||||
|
|
||||||
|
def _device_removed_mock(ipv6, zeroconf, services, handlers):
|
||||||
|
"""Call service update handler."""
|
||||||
|
handlers[0](
|
||||||
|
zeroconf,
|
||||||
|
"_http._tcp.local.",
|
||||||
|
"Shelly108._http._tcp.local.",
|
||||||
|
ServiceStateChange.Removed,
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.dict(
|
||||||
|
zc_gen.ZEROCONF,
|
||||||
|
{
|
||||||
|
"_http._tcp.local.": [
|
||||||
|
{
|
||||||
|
"domain": "shelly",
|
||||||
|
"name": "shelly*",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
},
|
||||||
|
clear=True,
|
||||||
|
), patch.object(
|
||||||
|
hass.config_entries.flow,
|
||||||
|
"async_progress_by_init_data_type",
|
||||||
|
return_value=[{"flow_id": "mock_flow_id"}],
|
||||||
|
) as mock_async_progress_by_init_data_type, patch.object(
|
||||||
|
hass.config_entries.flow, "async_abort"
|
||||||
|
) as mock_async_abort, patch.object(
|
||||||
|
zeroconf, "HaAsyncServiceBrowser", side_effect=_device_removed_mock
|
||||||
|
) as mock_service_browser, patch(
|
||||||
|
"homeassistant.components.zeroconf.AsyncServiceInfo",
|
||||||
|
side_effect=get_zeroconf_info_mock("FFAADDCC11DD"),
|
||||||
|
):
|
||||||
|
assert await async_setup_component(hass, zeroconf.DOMAIN, {zeroconf.DOMAIN: {}})
|
||||||
|
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
assert len(mock_service_browser.mock_calls) == 1
|
||||||
|
assert len(mock_async_progress_by_init_data_type.mock_calls) == 1
|
||||||
|
assert mock_async_abort.mock_calls[0][1][0] == "mock_flow_id"
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test the flow classes."""
|
"""Test the flow classes."""
|
||||||
|
import dataclasses
|
||||||
import logging
|
import logging
|
||||||
from unittest.mock import Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
@ -73,7 +74,7 @@ async def test_configure_reuses_handler_instance(manager):
|
|||||||
assert len(manager.mock_created_entries) == 0
|
assert len(manager.mock_created_entries) == 0
|
||||||
|
|
||||||
|
|
||||||
async def test_configure_two_steps(manager):
|
async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
|
||||||
"""Test that we reuse instances."""
|
"""Test that we reuse instances."""
|
||||||
|
|
||||||
@manager.mock_reg_handler("test")
|
@manager.mock_reg_handler("test")
|
||||||
@ -82,7 +83,6 @@ async def test_configure_two_steps(manager):
|
|||||||
|
|
||||||
async def async_step_first(self, user_input=None):
|
async def async_step_first(self, user_input=None):
|
||||||
if user_input is not None:
|
if user_input is not None:
|
||||||
self.init_data = user_input
|
|
||||||
return await self.async_step_second()
|
return await self.async_step_second()
|
||||||
return self.async_show_form(step_id="first", data_schema=vol.Schema([str]))
|
return self.async_show_form(step_id="first", data_schema=vol.Schema([str]))
|
||||||
|
|
||||||
@ -93,12 +93,13 @@ async def test_configure_two_steps(manager):
|
|||||||
)
|
)
|
||||||
return self.async_show_form(step_id="second", data_schema=vol.Schema([str]))
|
return self.async_show_form(step_id="second", data_schema=vol.Schema([str]))
|
||||||
|
|
||||||
form = await manager.async_init("test", context={"init_step": "first"})
|
form = await manager.async_init(
|
||||||
|
"test", context={"init_step": "first"}, data=["INIT-DATA"]
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(vol.Invalid):
|
with pytest.raises(vol.Invalid):
|
||||||
form = await manager.async_configure(form["flow_id"], "INCORRECT-DATA")
|
form = await manager.async_configure(form["flow_id"], "INCORRECT-DATA")
|
||||||
|
|
||||||
form = await manager.async_configure(form["flow_id"], ["INIT-DATA"])
|
|
||||||
form = await manager.async_configure(form["flow_id"], ["SECOND-DATA"])
|
form = await manager.async_configure(form["flow_id"], ["SECOND-DATA"])
|
||||||
assert form["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
assert form["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||||
assert len(manager.async_progress()) == 0
|
assert len(manager.async_progress()) == 0
|
||||||
@ -553,3 +554,102 @@ async def test_show_menu(hass, manager, menu_options):
|
|||||||
)
|
)
|
||||||
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
assert result["type"] == data_entry_flow.FlowResultType.FORM
|
||||||
assert result["step_id"] == "target1"
|
assert result["step_id"] == "target1"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_find_flows_by_init_data_type(
|
||||||
|
manager: data_entry_flow.FlowManager,
|
||||||
|
) -> None:
|
||||||
|
"""Test we can find flows by init data type."""
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class BluetoothDiscoveryData:
|
||||||
|
"""Bluetooth Discovery data."""
|
||||||
|
|
||||||
|
address: str
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class WiFiDiscoveryData:
|
||||||
|
"""WiFi Discovery data."""
|
||||||
|
|
||||||
|
address: str
|
||||||
|
|
||||||
|
@manager.mock_reg_handler("test")
|
||||||
|
class TestFlow(data_entry_flow.FlowHandler):
|
||||||
|
VERSION = 1
|
||||||
|
|
||||||
|
async def async_step_first(self, user_input=None):
|
||||||
|
if user_input is not None:
|
||||||
|
return await self.async_step_second()
|
||||||
|
return self.async_show_form(step_id="first", data_schema=vol.Schema([str]))
|
||||||
|
|
||||||
|
async def async_step_second(self, user_input=None):
|
||||||
|
if user_input is not None:
|
||||||
|
return self.async_create_entry(
|
||||||
|
title="Test Entry",
|
||||||
|
data={"init": self.init_data, "user": user_input},
|
||||||
|
)
|
||||||
|
return self.async_show_form(step_id="second", data_schema=vol.Schema([str]))
|
||||||
|
|
||||||
|
bluetooth_data = BluetoothDiscoveryData("aa:bb:cc:dd:ee:ff")
|
||||||
|
wifi_data = WiFiDiscoveryData("host")
|
||||||
|
|
||||||
|
bluetooth_form = await manager.async_init(
|
||||||
|
"test", context={"init_step": "first"}, data=bluetooth_data
|
||||||
|
)
|
||||||
|
await manager.async_init("test", context={"init_step": "first"}, data=wifi_data)
|
||||||
|
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
manager.async_progress_by_init_data_type(
|
||||||
|
BluetoothDiscoveryData, lambda data: True
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) == 1
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
manager.async_progress_by_init_data_type(
|
||||||
|
BluetoothDiscoveryData,
|
||||||
|
lambda data: bool(data.address == "aa:bb:cc:dd:ee:ff"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) == 1
|
||||||
|
assert (
|
||||||
|
len(
|
||||||
|
manager.async_progress_by_init_data_type(
|
||||||
|
BluetoothDiscoveryData, lambda data: bool(data.address == "not it")
|
||||||
|
)
|
||||||
|
)
|
||||||
|
) == 0
|
||||||
|
|
||||||
|
wifi_flows = manager.async_progress_by_init_data_type(
|
||||||
|
WiFiDiscoveryData, lambda data: True
|
||||||
|
)
|
||||||
|
assert len(wifi_flows) == 1
|
||||||
|
|
||||||
|
bluetooth_result = await manager.async_configure(
|
||||||
|
bluetooth_form["flow_id"], ["SECOND-DATA"]
|
||||||
|
)
|
||||||
|
assert bluetooth_result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
|
||||||
|
assert len(manager.async_progress()) == 1
|
||||||
|
assert len(manager.mock_created_entries) == 1
|
||||||
|
result = manager.mock_created_entries[0]
|
||||||
|
assert result["handler"] == "test"
|
||||||
|
assert result["data"] == {"init": bluetooth_data, "user": ["SECOND-DATA"]}
|
||||||
|
|
||||||
|
bluetooth_flows = manager.async_progress_by_init_data_type(
|
||||||
|
BluetoothDiscoveryData, lambda data: True
|
||||||
|
)
|
||||||
|
assert len(bluetooth_flows) == 0
|
||||||
|
|
||||||
|
wifi_flows = manager.async_progress_by_init_data_type(
|
||||||
|
WiFiDiscoveryData, lambda data: True
|
||||||
|
)
|
||||||
|
assert len(wifi_flows) == 1
|
||||||
|
|
||||||
|
manager.async_abort(wifi_flows[0]["flow_id"])
|
||||||
|
|
||||||
|
wifi_flows = manager.async_progress_by_init_data_type(
|
||||||
|
WiFiDiscoveryData, lambda data: True
|
||||||
|
)
|
||||||
|
assert len(wifi_flows) == 0
|
||||||
|
assert len(manager.async_progress()) == 0
|
||||||
|
Loading…
x
Reference in New Issue
Block a user