diff --git a/homeassistant/components/dhcp/__init__.py b/homeassistant/components/dhcp/__init__.py index 3e4fd8fec01..d52b30ccfb2 100644 --- a/homeassistant/components/dhcp/__init__.py +++ b/homeassistant/components/dhcp/__init__.py @@ -1,6 +1,5 @@ """The dhcp integration.""" -from abc import abstractmethod from datetime import timedelta import fnmatch from ipaddress import ip_address as make_ip_address @@ -17,6 +16,7 @@ from aiodiscover.discovery import ( from scapy.config import conf from scapy.error import Scapy_Exception +from homeassistant import config_entries from homeassistant.components.device_tracker.const import ( ATTR_HOST_NAME, ATTR_IP, @@ -31,6 +31,7 @@ from homeassistant.const import ( STATE_HOME, ) from homeassistant.core import Event, HomeAssistant, State, callback +from homeassistant.helpers import discovery_flow from homeassistant.helpers.device_registry import format_mac from homeassistant.helpers.event import ( async_track_state_added_domain, @@ -38,10 +39,9 @@ from homeassistant.helpers.event import ( ) from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_dhcp +from homeassistant.util.async_ import run_callback_threadsafe from homeassistant.util.network import is_invalid, is_link_local, is_loopback -from .const import DOMAIN - FILTER = "udp and (port 67 or 68)" REQUESTED_ADDR = "requested_addr" MESSAGE_TYPE = "message-type" @@ -89,6 +89,17 @@ class WatcherBase: self._address_data = address_data def process_client(self, ip_address, hostname, mac_address): + """Process a client.""" + return run_callback_threadsafe( + self.hass.loop, + self.async_process_client, + ip_address, + hostname, + mac_address, + ).result() + + @callback + def async_process_client(self, ip_address, hostname, mac_address): """Process a client.""" made_ip_address = make_ip_address(ip_address) @@ -101,7 +112,6 @@ class WatcherBase: return data = self._address_data.get(ip_address) - if ( data and data[MAC_ADDRESS] == mac_address @@ -111,12 +121,9 @@ class WatcherBase: # to process it return - self._address_data[ip_address] = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} + data = {MAC_ADDRESS: mac_address, HOSTNAME: hostname} + self._address_data[ip_address] = data - self.process_updated_address_data(ip_address, self._address_data[ip_address]) - - def process_updated_address_data(self, ip_address, data): - """Process the address data update.""" lowercase_hostname = data[HOSTNAME].lower() uppercase_mac = data[MAC_ADDRESS].upper() @@ -139,23 +146,17 @@ class WatcherBase: continue _LOGGER.debug("Matched %s against %s", data, entry) - - self.create_task( - self.hass.config_entries.flow.async_init( - entry["domain"], - context={"source": DOMAIN}, - data={ - IP_ADDRESS: ip_address, - HOSTNAME: lowercase_hostname, - MAC_ADDRESS: data[MAC_ADDRESS], - }, - ) + discovery_flow.async_create_flow( + self.hass, + entry["domain"], + {"source": config_entries.SOURCE_DHCP}, + { + IP_ADDRESS: ip_address, + HOSTNAME: lowercase_hostname, + MAC_ADDRESS: data[MAC_ADDRESS], + }, ) - @abstractmethod - def create_task(self, task): - """Pass a task to async_add_task based on which context we are in.""" - class NetworkWatcher(WatcherBase): """Class to query ptr records routers.""" @@ -189,21 +190,17 @@ class NetworkWatcher(WatcherBase): """Start a new discovery task if one is not running.""" if self._discover_task and not self._discover_task.done(): return - self._discover_task = self.create_task(self.async_discover()) + self._discover_task = self.hass.async_create_task(self.async_discover()) async def async_discover(self): """Process discovery.""" for host in await self._discover_hosts.async_discover(): - self.process_client( + self.async_process_client( host[DISCOVERY_IP_ADDRESS], host[DISCOVERY_HOSTNAME], _format_mac(host[DISCOVERY_MAC_ADDRESS]), ) - def create_task(self, task): - """Pass a task to async_create_task since we are in async context.""" - return self.hass.async_create_task(task) - class DeviceTrackerWatcher(WatcherBase): """Class to watch dhcp data from routers.""" @@ -250,11 +247,7 @@ class DeviceTrackerWatcher(WatcherBase): if ip_address is None or mac_address is None: return - self.process_client(ip_address, hostname, _format_mac(mac_address)) - - def create_task(self, task): - """Pass a task to async_create_task since we are in async context.""" - return self.hass.async_create_task(task) + self.async_process_client(ip_address, hostname, _format_mac(mac_address)) class DHCPWatcher(WatcherBase): @@ -353,10 +346,6 @@ class DHCPWatcher(WatcherBase): if self._sniffer.thread: self._sniffer.thread.name = self.__class__.__name__ - def create_task(self, task): - """Pass a task to hass.add_job since we are in a thread.""" - return self.hass.add_job(task) - def _decode_dhcp_option(dhcp_options, key): """Extract and decode data from a packet option.""" diff --git a/homeassistant/components/ssdp/__init__.py b/homeassistant/components/ssdp/__init__.py index b06f1b34493..da46fc565d2 100644 --- a/homeassistant/components/ssdp/__init__.py +++ b/homeassistant/components/ssdp/__init__.py @@ -18,19 +18,14 @@ from async_upnp_client.utils import CaseInsensitiveDict from homeassistant import config_entries from homeassistant.components import network -from homeassistant.const import ( - EVENT_HOMEASSISTANT_STARTED, - EVENT_HOMEASSISTANT_STOP, - MATCH_ALL, -) +from homeassistant.const import EVENT_HOMEASSISTANT_STOP, MATCH_ALL from homeassistant.core import HomeAssistant, callback as core_callback +from homeassistant.helpers import discovery_flow from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_ssdp, bind_hass -from .flow import FlowDispatcher, SSDPFlow - DOMAIN = "ssdp" SCAN_INTERVAL = timedelta(seconds=60) @@ -222,7 +217,6 @@ class Scanner: self._cancel_scan: Callable[[], None] | None = None self._ssdp_listeners: list[SsdpListener] = [] self._callbacks: list[tuple[SsdpCallback, dict[str, str]]] = [] - self._flow_dispatcher: FlowDispatcher | None = None self._description_cache: DescriptionCache | None = None self.integration_matchers = integration_matchers @@ -327,14 +321,10 @@ class Scanner: session = async_get_clientsession(self.hass) requester = AiohttpSessionRequester(session, True, 10) self._description_cache = DescriptionCache(requester) - self._flow_dispatcher = FlowDispatcher(self.hass) await self._async_start_ssdp_listeners() self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, self.async_stop) - self.hass.bus.async_listen_once( - EVENT_HOMEASSISTANT_STARTED, self._flow_dispatcher.async_start - ) self._cancel_scan = async_track_time_interval( self.hass, self.async_scan, SCAN_INTERVAL ) @@ -417,13 +407,12 @@ class Scanner: for domain in matching_domains: _LOGGER.debug("Discovered %s at %s", domain, location) - flow: SSDPFlow = { - "domain": domain, - "context": {"source": config_entries.SOURCE_SSDP}, - "data": discovery_info, - } - assert self._flow_dispatcher is not None - self._flow_dispatcher.create(flow) + discovery_flow.async_create_flow( + self.hass, + domain, + {"source": config_entries.SOURCE_SSDP}, + discovery_info, + ) async def _async_get_description_dict( self, location: str | None diff --git a/homeassistant/components/ssdp/flow.py b/homeassistant/components/ssdp/flow.py deleted file mode 100644 index 77f4cb107b8..00000000000 --- a/homeassistant/components/ssdp/flow.py +++ /dev/null @@ -1,50 +0,0 @@ -"""The SSDP integration.""" -from __future__ import annotations - -from collections.abc import Coroutine -from typing import Any, TypedDict - -from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult - - -class SSDPFlow(TypedDict): - """A queued ssdp discovery flow.""" - - domain: str - context: dict[str, Any] - data: dict - - -class FlowDispatcher: - """Dispatch discovery flows.""" - - def __init__(self, hass: HomeAssistant) -> None: - """Init the discovery dispatcher.""" - self.hass = hass - self.pending_flows: list[SSDPFlow] = [] - self.started = False - - @callback - def async_start(self, *_: Any) -> None: - """Start processing pending flows.""" - self.started = True - self.hass.loop.call_soon(self._async_process_pending_flows) - - def _async_process_pending_flows(self) -> None: - for flow in self.pending_flows: - self.hass.async_create_task(self._init_flow(flow)) - self.pending_flows = [] - - def create(self, flow: SSDPFlow) -> None: - """Create and add or queue a flow.""" - if self.started: - self.hass.async_create_task(self._init_flow(flow)) - else: - self.pending_flows.append(flow) - - def _init_flow(self, flow: SSDPFlow) -> Coroutine[None, None, FlowResult]: - """Create a flow.""" - return self.hass.config_entries.flow.async_init( - flow["domain"], context=flow["context"], data=flow["data"] - ) diff --git a/homeassistant/components/usb/__init__.py b/homeassistant/components/usb/__init__.py index 095d72f3ed4..80d01417ea7 100644 --- a/homeassistant/components/usb/__init__.py +++ b/homeassistant/components/usb/__init__.py @@ -16,13 +16,12 @@ from homeassistant.components import websocket_api from homeassistant.components.websocket_api.connection import ActiveConnection from homeassistant.const import EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.helpers import system_info +from homeassistant.helpers import discovery_flow, system_info from homeassistant.helpers.debounce import Debouncer from homeassistant.helpers.typing import ConfigType from homeassistant.loader import async_get_usb from .const import DOMAIN -from .flow import FlowDispatcher, USBFlow from .models import USBDevice from .utils import usb_device_from_port @@ -65,7 +64,7 @@ def get_serial_by_id(dev_path: str) -> str: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the USB Discovery integration.""" usb = await async_get_usb(hass) - usb_discovery = USBDiscovery(hass, FlowDispatcher(hass), usb) + usb_discovery = USBDiscovery(hass, usb) await usb_discovery.async_setup() hass.data[DOMAIN] = usb_discovery websocket_api.async_register_command(hass, websocket_usb_scan) @@ -86,12 +85,10 @@ class USBDiscovery: def __init__( self, hass: HomeAssistant, - flow_dispatcher: FlowDispatcher, usb: list[dict[str, str]], ) -> None: """Init USB Discovery.""" self.hass = hass - self.flow_dispatcher = flow_dispatcher self.usb = usb self.seen: set[tuple[str, ...]] = set() self.observer_active = False @@ -104,7 +101,6 @@ class USBDiscovery: async def async_start(self, event: Event) -> None: """Start USB Discovery and run a manual scan.""" - self.flow_dispatcher.async_start() await self._async_scan_serial() async def _async_start_monitor(self) -> None: @@ -193,12 +189,12 @@ class USBDiscovery: if len(matcher) < most_matched_fields: break - flow: USBFlow = { - "domain": matcher["domain"], - "context": {"source": config_entries.SOURCE_USB}, - "data": dataclasses.asdict(device), - } - self.flow_dispatcher.async_create(flow) + discovery_flow.async_create_flow( + self.hass, + matcher["domain"], + {"source": config_entries.SOURCE_USB}, + dataclasses.asdict(device), + ) @callback def _async_process_ports(self, ports: list[ListPortInfo]) -> None: diff --git a/homeassistant/components/usb/flow.py b/homeassistant/components/usb/flow.py deleted file mode 100644 index 00c40add92a..00000000000 --- a/homeassistant/components/usb/flow.py +++ /dev/null @@ -1,48 +0,0 @@ -"""The USB Discovery integration.""" -from __future__ import annotations - -from collections.abc import Coroutine -from typing import Any, TypedDict - -from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult - - -class USBFlow(TypedDict): - """A queued usb discovery flow.""" - - domain: str - context: dict[str, Any] - data: dict - - -class FlowDispatcher: - """Dispatch discovery flows.""" - - def __init__(self, hass: HomeAssistant) -> None: - """Init the discovery dispatcher.""" - self.hass = hass - self.pending_flows: list[USBFlow] = [] - self.started = False - - @callback - def async_start(self, *_: Any) -> None: - """Start processing pending flows.""" - self.started = True - for flow in self.pending_flows: - self.hass.async_create_task(self._init_flow(flow)) - self.pending_flows = [] - - @callback - def async_create(self, flow: USBFlow) -> None: - """Create and add or queue a flow.""" - if self.started: - self.hass.async_create_task(self._init_flow(flow)) - else: - self.pending_flows.append(flow) - - def _init_flow(self, flow: USBFlow) -> Coroutine[None, None, FlowResult]: - """Create a flow.""" - return self.hass.config_entries.flow.async_init( - flow["domain"], context=flow["context"], data=flow["data"] - ) diff --git a/homeassistant/components/zeroconf/__init__.py b/homeassistant/components/zeroconf/__init__.py index 4afb0a3c24d..1d72c7d20e9 100644 --- a/homeassistant/components/zeroconf/__init__.py +++ b/homeassistant/components/zeroconf/__init__.py @@ -2,7 +2,6 @@ from __future__ import annotations import asyncio -from collections.abc import Coroutine from contextlib import suppress import fnmatch from ipaddress import IPv4Address, IPv6Address, ip_address @@ -21,12 +20,11 @@ from homeassistant.components.network import async_get_source_ip from homeassistant.components.network.models import Adapter from homeassistant.const import ( EVENT_HOMEASSISTANT_START, - EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, __version__, ) from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers import discovery_flow import homeassistant.helpers.config_validation as cv from homeassistant.helpers.network import NoURLAvailableError, get_url from homeassistant.helpers.typing import ConfigType @@ -91,14 +89,6 @@ class HaServiceInfo(TypedDict): properties: dict[str, Any] -class ZeroconfFlow(TypedDict): - """A queued zeroconf discovery flow.""" - - domain: str - context: dict[str, Any] - data: HaServiceInfo - - @bind_hass async def async_get_instance(hass: HomeAssistant) -> HaZeroconf: """Zeroconf instance to be shared with other integrations that use it.""" @@ -192,17 +182,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: uuid = await hass.helpers.instance_id.async_get() await _async_register_hass_zc_service(hass, aio_zc, uuid) - @callback - def _async_start_discovery(_event: Event) -> None: - """Start processing flows.""" - discovery.async_start() - async def _async_zeroconf_hass_stop(_event: Event) -> None: await discovery.async_stop() hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, _async_zeroconf_hass_stop) hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, _async_zeroconf_hass_start) - hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, _async_start_discovery) return True @@ -288,40 +272,6 @@ async def _async_register_hass_zc_service( await aio_zc.async_register_service(info, allow_name_change=True) -class FlowDispatcher: - """Dispatch discovery flows.""" - - def __init__(self, hass: HomeAssistant) -> None: - """Init the discovery dispatcher.""" - self.hass = hass - self.pending_flows: list[ZeroconfFlow] = [] - self.started = False - - @callback - def async_start(self) -> None: - """Start processing pending flows.""" - self.started = True - self.hass.loop.call_soon(self._async_process_pending_flows) - - def _async_process_pending_flows(self) -> None: - for flow in self.pending_flows: - self.hass.async_create_task(self._init_flow(flow)) - self.pending_flows = [] - - def async_create(self, flow: ZeroconfFlow) -> None: - """Create and add or queue a flow.""" - if self.started: - self.hass.async_create_task(self._init_flow(flow)) - else: - self.pending_flows.append(flow) - - def _init_flow(self, flow: ZeroconfFlow) -> Coroutine[None, None, FlowResult]: - """Create a flow.""" - return self.hass.config_entries.flow.async_init( - flow["domain"], context=flow["context"], data=flow["data"] - ) - - class ZeroconfDiscovery: """Discovery via zeroconf.""" @@ -340,12 +290,10 @@ class ZeroconfDiscovery: self.homekit_models = homekit_models self.ipv6 = ipv6 - self.flow_dispatcher: FlowDispatcher | None = None self.async_service_browser: HaAsyncServiceBrowser | None = None async def async_setup(self) -> None: """Start discovery.""" - self.flow_dispatcher = FlowDispatcher(self.hass) types = list(self.zeroconf_types) # We want to make sure we know about other HomeAssistant # instances as soon as possible to avoid name conflicts @@ -363,12 +311,6 @@ class ZeroconfDiscovery: if self.async_service_browser: await self.async_service_browser.async_cancel() - @callback - def async_start(self) -> None: - """Start processing discovery flows.""" - assert self.flow_dispatcher is not None - self.flow_dispatcher.async_start() - @callback def async_service_update( self, @@ -404,12 +346,14 @@ class ZeroconfDiscovery: return _LOGGER.debug("Discovered new device %s %s", name, info) - assert self.flow_dispatcher is not None # If we can handle it as a HomeKit discovery, we do that here. if service_type in HOMEKIT_TYPES: - if pending_flow := handle_homekit(self.hass, self.homekit_models, info): - self.flow_dispatcher.async_create(pending_flow) + props = info["properties"] + if domain := async_get_homekit_discovery_domain(self.homekit_models, props): + discovery_flow.async_create_flow( + self.hass, domain, {"source": config_entries.SOURCE_HOMEKIT}, info + ) # Continue on here as homekit_controller # still needs to get updates on devices # so it can see when the 'c#' field is updated. @@ -417,10 +361,10 @@ class ZeroconfDiscovery: # We only send updates to homekit_controller # if the device is already paired in order to avoid # offering a second discovery for the same device - if pending_flow and HOMEKIT_PAIRED_STATUS_FLAG in info["properties"]: + if domain and HOMEKIT_PAIRED_STATUS_FLAG in props: try: # 0 means paired and not discoverable by iOS clients) - if int(info["properties"][HOMEKIT_PAIRED_STATUS_FLAG]): + if int(props[HOMEKIT_PAIRED_STATUS_FLAG]): return except ValueError: # HomeKit pairing status unknown @@ -466,24 +410,22 @@ class ZeroconfDiscovery: ): continue - flow: ZeroconfFlow = { - "domain": matcher["domain"], - "context": {"source": config_entries.SOURCE_ZEROCONF}, - "data": info, - } - self.flow_dispatcher.async_create(flow) + discovery_flow.async_create_flow( + self.hass, + matcher["domain"], + {"source": config_entries.SOURCE_ZEROCONF}, + info, + ) -def handle_homekit( - hass: HomeAssistant, homekit_models: dict[str, str], info: HaServiceInfo -) -> ZeroconfFlow | None: +def async_get_homekit_discovery_domain( + homekit_models: dict[str, str], props: dict[str, Any] +) -> str | None: """Handle a HomeKit discovery. - Return if discovery was forwarded. + Return the domain to forward the discovery data to """ model = None - props = info["properties"] - for key in props: if key.lower() == HOMEKIT_MODEL: model = props[key] @@ -500,11 +442,7 @@ def handle_homekit( ): continue - return { - "domain": homekit_models[test_model], - "context": {"source": config_entries.SOURCE_HOMEKIT}, - "data": info, - } + return homekit_models[test_model] return None diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 63d5566db40..791fd9d21c5 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -120,6 +120,19 @@ class FlowManager(abc.ABC): async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None: """Entry has finished executing its first step asynchronously.""" + @callback + def async_has_matching_flow( + self, handler: str, context: dict[str, Any], data: Any + ) -> bool: + """Check if an existing matching flow is in progress with the same handler, context, and data.""" + return any( + flow + for flow in self._progress.values() + if flow.handler == handler + and flow.context["source"] == context["source"] + and flow.init_data == data + ) + @callback def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]: """Return the flows in progress.""" @@ -173,6 +186,7 @@ class FlowManager(abc.ABC): flow.handler = handler flow.flow_id = uuid.uuid4().hex flow.context = context + flow.init_data = data self._progress[flow.flow_id] = flow result = await self._async_handle_step(flow, flow.init_step, data, init_done) return flow, result @@ -318,6 +332,9 @@ class FlowHandler: # Set by _async_create_flow callback init_step = "init" + # The initial data that was used to start the flow + init_data: Any = None + # Set by developer VERSION = 1 diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py new file mode 100644 index 00000000000..5bb0da2dc05 --- /dev/null +++ b/homeassistant/helpers/discovery_flow.py @@ -0,0 +1,82 @@ +"""The discovery flow helper.""" +from __future__ import annotations + +from collections.abc import Coroutine +from typing import Any + +from homeassistant.const import EVENT_HOMEASSISTANT_STARTED +from homeassistant.core import CoreState, Event, HomeAssistant, callback +from homeassistant.data_entry_flow import FlowResult +from homeassistant.loader import bind_hass +from homeassistant.util.async_ import gather_with_concurrency + +FLOW_INIT_LIMIT = 2 +DISCOVERY_FLOW_DISPATCHER = "discovery_flow_disptacher" + + +@bind_hass +@callback +def async_create_flow( + hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any +) -> None: + """Create a discovery flow.""" + if hass.state == CoreState.running: + if init_coro := _async_init_flow(hass, domain, context, data): + hass.async_create_task(init_coro) + return + + if DISCOVERY_FLOW_DISPATCHER not in hass.data: + dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] = FlowDispatcher(hass) + dispatcher.async_setup() + else: + dispatcher = hass.data[DISCOVERY_FLOW_DISPATCHER] + + return dispatcher.async_create(domain, context, data) + + +@callback +def _async_init_flow( + hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any +) -> Coroutine[None, None, FlowResult] | None: + """Create a discovery flow.""" + # Avoid spawning flows that have the same initial discovery data + # as ones in progress as it may cause additional device probing + # which can overload devices since zeroconf/ssdp updates can happen + # multiple times in the same minute + if hass.config_entries.flow.async_has_matching_flow(domain, context, data): + return None + + return hass.config_entries.flow.async_init(domain, context=context, data=data) + + +class FlowDispatcher: + """Dispatch discovery flows.""" + + def __init__(self, hass: HomeAssistant) -> None: + """Init the discovery dispatcher.""" + self.hass = hass + self.pending_flows: list[tuple[str, dict[str, Any], Any]] = [] + + @callback + def async_setup(self) -> None: + """Set up the flow disptcher.""" + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STARTED, self.async_start) + + @callback + def async_start(self, event: Event) -> None: + """Start processing pending flows.""" + self.hass.data.pop(DISCOVERY_FLOW_DISPATCHER) + self.hass.async_create_task(self._async_process_pending_flows()) + + async def _async_process_pending_flows(self) -> None: + """Process any pending discovery flows.""" + init_coros = [_async_init_flow(self.hass, *flow) for flow in self.pending_flows] + await gather_with_concurrency( + FLOW_INIT_LIMIT, + *[init_coro for init_coro in init_coros if init_coro is not None], + ) + + @callback + def async_create(self, domain: str, context: dict[str, Any], data: Any) -> None: + """Create and add or queue a flow.""" + self.pending_flows.append((domain, context, data)) diff --git a/tests/components/dhcp/test_init.py b/tests/components/dhcp/test_init.py index f00a0135e8d..dc50edbeb10 100644 --- a/tests/components/dhcp/test_init.py +++ b/tests/components/dhcp/test_init.py @@ -3,6 +3,7 @@ import datetime import threading from unittest.mock import MagicMock, patch +from scapy import arch # pylint: unused-import # noqa: F401 from scapy.error import Scapy_Exception from scapy.layers.dhcp import DHCP from scapy.layers.l2 import Ether @@ -16,6 +17,7 @@ from homeassistant.components.device_tracker.const import ( ATTR_SOURCE_TYPE, SOURCE_TYPE_ROUTER, ) +from homeassistant.components.dhcp.const import DOMAIN from homeassistant.const import ( EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, @@ -129,11 +131,16 @@ async def _async_get_handle_dhcp_packet(hass, integration_matchers): {}, integration_matchers, ) - handle_dhcp_packet = None + async_handle_dhcp_packet = None def _mock_sniffer(*args, **kwargs): - nonlocal handle_dhcp_packet - handle_dhcp_packet = kwargs["prn"] + nonlocal async_handle_dhcp_packet + callback = kwargs["prn"] + + async def _async_handle_dhcp_packet(packet): + await hass.async_add_executor_job(callback, packet) + + async_handle_dhcp_packet = _async_handle_dhcp_packet return MagicMock() with patch("homeassistant.components.dhcp._verify_l2socket_setup",), patch( @@ -141,7 +148,7 @@ async def _async_get_handle_dhcp_packet(hass, integration_matchers): ), patch("scapy.sendrecv.AsyncSniffer", _mock_sniffer): await dhcp_watcher.async_start() - return handle_dhcp_packet + return async_handle_dhcp_packet async def test_dhcp_match_hostname_and_macaddress(hass): @@ -151,11 +158,13 @@ async def test_dhcp_match_hostname_and_macaddress(hass): ] packet = Ether(RAW_DHCP_REQUEST) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) # Ensure no change is ignored - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -177,11 +186,13 @@ async def test_dhcp_renewal_match_hostname_and_macaddress(hass): packet = Ether(RAW_DHCP_RENEWAL) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) # Ensure no change is ignored - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -201,9 +212,11 @@ async def test_dhcp_match_hostname(hass): packet = Ether(RAW_DHCP_REQUEST) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -223,9 +236,11 @@ async def test_dhcp_match_macaddress(hass): packet = Ether(RAW_DHCP_REQUEST) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -245,9 +260,11 @@ async def test_dhcp_match_macaddress_without_hostname(hass): packet = Ether(RAW_DHCP_REQUEST_WITHOUT_HOSTNAME) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 1 assert mock_init.mock_calls[0][1][0] == "mock-domain" @@ -267,9 +284,11 @@ async def test_dhcp_nomatch(hass): packet = Ether(RAW_DHCP_REQUEST) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -280,9 +299,11 @@ async def test_dhcp_nomatch_hostname(hass): packet = Ether(RAW_DHCP_REQUEST) - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -293,9 +314,11 @@ async def test_dhcp_nomatch_non_dhcp_packet(hass): packet = Ether(b"") - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -315,9 +338,11 @@ async def test_dhcp_nomatch_non_dhcp_request_packet(hass): ("hostname", b"connect"), ] - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -337,9 +362,11 @@ async def test_dhcp_invalid_hostname(hass): ("hostname", "connect"), ] - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -359,9 +386,11 @@ async def test_dhcp_missing_hostname(hass): ("hostname", None), ] - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -381,9 +410,11 @@ async def test_dhcp_invalid_option(hass): ("hostname"), ] - handle_dhcp_packet = await _async_get_handle_dhcp_packet(hass, integration_matchers) + async_handle_dhcp_packet = await _async_get_handle_dhcp_packet( + hass, integration_matchers + ) with patch.object(hass.config_entries.flow, "async_init") as mock_init: - handle_dhcp_packet(packet) + await async_handle_dhcp_packet(packet) assert len(mock_init.mock_calls) == 0 @@ -393,7 +424,7 @@ async def test_setup_and_stop(hass): assert await async_setup_component( hass, - dhcp.DOMAIN, + DOMAIN, {}, ) await hass.async_block_till_done() @@ -417,7 +448,7 @@ async def test_setup_fails_as_root(hass, caplog): assert await async_setup_component( hass, - dhcp.DOMAIN, + DOMAIN, {}, ) await hass.async_block_till_done() @@ -442,7 +473,7 @@ async def test_setup_fails_non_root(hass, caplog): assert await async_setup_component( hass, - dhcp.DOMAIN, + DOMAIN, {}, ) await hass.async_block_till_done() @@ -464,7 +495,7 @@ async def test_setup_fails_with_broken_libpcap(hass, caplog): assert await async_setup_component( hass, - dhcp.DOMAIN, + DOMAIN, {}, ) await hass.async_block_till_done() diff --git a/tests/components/ssdp/test_init.py b/tests/components/ssdp/test_init.py index 64edd9e8341..0304f8f067b 100644 --- a/tests/components/ssdp/test_init.py +++ b/tests/components/ssdp/test_init.py @@ -431,7 +431,9 @@ async def test_scan_with_registered_callback( "homeassistant.components.ssdp.async_get_ssdp", return_value={"mock-domain": [{"st": "mock-st"}]}, ) -async def test_getting_existing_headers(mock_get_ssdp, hass, aioclient_mock): +async def test_getting_existing_headers( + mock_get_ssdp, hass, aioclient_mock, mock_flow_init +): """Test getting existing/previously scanned headers.""" aioclient_mock.get( "http://1.1.1.1", diff --git a/tests/helpers/test_discovery_flow.py b/tests/helpers/test_discovery_flow.py new file mode 100644 index 00000000000..549848e5c7b --- /dev/null +++ b/tests/helpers/test_discovery_flow.py @@ -0,0 +1,71 @@ +"""Test the discovery flow helper.""" + +from unittest.mock import AsyncMock, call, patch + +import pytest + +from homeassistant import config_entries +from homeassistant.core import EVENT_HOMEASSISTANT_STARTED, CoreState +from homeassistant.helpers import discovery_flow + + +@pytest.fixture +def mock_flow_init(hass): + """Mock hass.config_entries.flow.async_init.""" + with patch.object( + hass.config_entries.flow, "async_init", return_value=AsyncMock() + ) as mock_init: + yield mock_init + + +async def test_async_create_flow(hass, mock_flow_init): + """Test we can create a flow.""" + discovery_flow.async_create_flow( + hass, + "hue", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + assert mock_flow_init.mock_calls == [ + call( + "hue", + context={"source": "homekit"}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + ] + + +async def test_async_create_flow_deferred_until_started(hass, mock_flow_init): + """Test flows are deferred until started.""" + hass.state = CoreState.stopped + discovery_flow.async_create_flow( + hass, + "hue", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + assert not mock_flow_init.mock_calls + hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) + await hass.async_block_till_done() + assert mock_flow_init.mock_calls == [ + call( + "hue", + context={"source": "homekit"}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + ] + + +async def test_async_create_flow_checks_existing_flows(hass, mock_flow_init): + """Test existing flows prevent an identical one from being creates.""" + with patch( + "homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", + return_value=True, + ): + discovery_flow.async_create_flow( + hass, + "hue", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + assert not mock_flow_init.mock_calls diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index 4b5777d86f8..0aa3c01d50f 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -6,6 +6,7 @@ import pytest import voluptuous as vol from homeassistant import config_entries, data_entry_flow +from homeassistant.core import HomeAssistant from homeassistant.util.decorator import Registry from tests.common import async_capture_events @@ -397,3 +398,54 @@ async def test_init_unknown_flow(manager): manager, "async_create_flow", return_value=None ): await manager.async_init("test") + + +async def test_async_has_matching_flow( + hass: HomeAssistant, manager: data_entry_flow.FlowManager +): + """Test we can check for matching flows.""" + manager.hass = hass + + @manager.mock_reg_handler("test") + class TestFlow(data_entry_flow.FlowHandler): + VERSION = 5 + + async def async_step_init(self, user_input=None): + return self.async_show_progress( + step_id="init", + progress_action="task_one", + ) + + result = await manager.async_init( + "test", + context={"source": config_entries.SOURCE_HOMEKIT}, + data={"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + assert result["type"] == data_entry_flow.RESULT_TYPE_SHOW_PROGRESS + assert result["progress_action"] == "task_one" + assert len(manager.async_progress()) == 1 + + assert ( + manager.async_has_matching_flow( + "test", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is True + ) + assert ( + manager.async_has_matching_flow( + "test", + {"source": config_entries.SOURCE_SSDP}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is False + ) + assert ( + manager.async_has_matching_flow( + "other", + {"source": config_entries.SOURCE_HOMEKIT}, + {"properties": {"id": "aa:bb:cc:dd:ee:ff"}}, + ) + is False + )