diff --git a/homeassistant/components/knx/config_flow.py b/homeassistant/components/knx/config_flow.py index e2ff0908bbe..c043ea65ee5 100644 --- a/homeassistant/components/knx/config_flow.py +++ b/homeassistant/components/knx/config_flow.py @@ -2,6 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod +from collections.abc import AsyncGenerator from typing import Any, Final import voluptuous as vol @@ -95,6 +96,9 @@ class KNXCommonFlow(ABC, FlowHandler): self._found_tunnels: list[GatewayDescriptor] = [] self._selected_tunnel: GatewayDescriptor | None = None + self._gatewayscanner: GatewayScanner | None = None + self._async_scan_gen: AsyncGenerator[GatewayDescriptor, None] | None = None + @abstractmethod def finish_flow(self, title: str) -> FlowResult: """Finish the flow.""" @@ -104,6 +108,13 @@ class KNXCommonFlow(ABC, FlowHandler): ) -> FlowResult: """Handle connection type configuration.""" if user_input is not None: + if self._async_scan_gen: + await self._async_scan_gen.aclose() # stop the scan + self._async_scan_gen = None + if self._gatewayscanner: + self._found_gateways = list( + self._gatewayscanner.found_gateways.values() + ) connection_type = user_input[CONF_KNX_CONNECTION_TYPE] if connection_type == CONF_KNX_ROUTING: return await self.async_step_routing() @@ -129,8 +140,21 @@ class KNXCommonFlow(ABC, FlowHandler): CONF_KNX_TUNNELING: CONF_KNX_TUNNELING.capitalize(), CONF_KNX_ROUTING: CONF_KNX_ROUTING.capitalize(), } - self._found_gateways = await scan_for_gateways() - if self._found_gateways: + + if isinstance(self, OptionsFlow) and (knx_module := self.hass.data.get(DOMAIN)): + xknx = knx_module.xknx + else: + xknx = XKNX() + self._gatewayscanner = GatewayScanner( + xknx, stop_on_found=0, timeout_in_seconds=2 + ) + # keep a reference to the generator to scan in background until user selects a connection type + self._async_scan_gen = self._gatewayscanner.async_scan() + try: + await self._async_scan_gen.__anext__() # pylint: disable=unnecessary-dunder-call + except StopAsyncIteration: + pass # scan finished, no interfaces discovered + else: # add automatic at first position only if a gateway responded supported_connection_types = { CONF_KNX_AUTOMATIC: CONF_KNX_AUTOMATIC.capitalize() @@ -614,12 +638,3 @@ class KNXOptionsFlow(KNXCommonFlow, OptionsFlow): data_schema=vol.Schema(data_schema), last_step=True, ) - - -async def scan_for_gateways(stop_on_found: int = 0) -> list[GatewayDescriptor]: - """Scan for gateways within the network.""" - xknx = XKNX() - gatewayscanner = GatewayScanner( - xknx, stop_on_found=stop_on_found, timeout_in_seconds=2 - ) - return await gatewayscanner.scan() diff --git a/tests/components/knx/test_config_flow.py b/tests/components/knx/test_config_flow.py index 2ce3793937b..eb593d20924 100644 --- a/tests/components/knx/test_config_flow.py +++ b/tests/components/knx/test_config_flow.py @@ -1,5 +1,5 @@ """Test the KNX config flow.""" -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest from xknx.exceptions.exception import InvalidSecureConfiguration @@ -67,6 +67,24 @@ def _gateway_descriptor( return descriptor +class GatewayScannerMock: + """Mock GatewayScanner.""" + + def __init__(self, gateways=None): + """Initialize GatewayScannerMock.""" + # Key is a HPAI instance in xknx, but not used in HA anyway. + self.found_gateways = ( + {f"{gateway.ip_addr}:{gateway.port}": gateway for gateway in gateways} + if gateways + else {} + ) + + async def async_scan(self): + """Mock async generator.""" + for gateway in self.found_gateways: + yield gateway + + async def test_user_single_instance(hass): """Test we only allow a single config flow.""" MockConfigEntry(domain=DOMAIN).add_to_hass(hass) @@ -78,15 +96,17 @@ async def test_user_single_instance(hass): assert result["reason"] == "single_instance_allowed" -async def test_routing_setup(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) +async def test_routing_setup(gateway_scanner_mock, hass: HomeAssistant) -> None: """Test routing setup.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -126,19 +146,23 @@ async def test_routing_setup(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 -async def test_routing_setup_advanced(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) +async def test_routing_setup_advanced( + gateway_scanner_mock, hass: HomeAssistant +) -> None: """Test routing setup with advanced options.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={ - "source": config_entries.SOURCE_USER, - "show_advanced_options": True, - }, - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={ + "source": config_entries.SOURCE_USER, + "show_advanced_options": True, + }, + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -200,15 +224,19 @@ async def test_routing_setup_advanced(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 -async def test_routing_secure_manual_setup(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) +async def test_routing_secure_manual_setup( + gateway_scanner_mock, hass: HomeAssistant +) -> None: """Test routing secure setup with manual key config.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -287,15 +315,19 @@ async def test_routing_secure_manual_setup(hass: HomeAssistant) -> None: assert len(mock_setup_entry.mock_calls) == 1 -async def test_routing_secure_keyfile(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) +async def test_routing_secure_keyfile( + gateway_scanner_mock, hass: HomeAssistant +) -> None: """Test routing secure setup with keyfile.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -412,17 +444,19 @@ async def test_routing_secure_keyfile(hass: HomeAssistant) -> None: ), ], ) +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) async def test_tunneling_setup_manual( - hass: HomeAssistant, user_input, config_entry_data + gateway_scanner_mock, hass: HomeAssistant, user_input, config_entry_data ) -> None: """Test tunneling if no gateway was found found (or `manual` option was chosen).""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -451,19 +485,23 @@ async def test_tunneling_setup_manual( assert len(mock_setup_entry.mock_calls) == 1 -async def test_tunneling_setup_for_local_ip(hass: HomeAssistant) -> None: +@patch( + "homeassistant.components.knx.config_flow.GatewayScanner", + return_value=GatewayScannerMock(), +) +async def test_tunneling_setup_for_local_ip( + gateway_scanner_mock, hass: HomeAssistant +) -> None: """Test tunneling if only one gateway is found.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [] - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={ - "source": config_entries.SOURCE_USER, - "show_advanced_options": True, - }, - ) - assert result["type"] == FlowResultType.FORM - assert not result["errors"] + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={ + "source": config_entries.SOURCE_USER, + "show_advanced_options": True, + }, + ) + assert result["type"] == FlowResultType.FORM + assert not result["errors"] result2 = await hass.config_entries.flow.async_configure( result["flow_id"], @@ -542,11 +580,13 @@ async def test_tunneling_setup_for_local_ip(hass: HomeAssistant) -> None: async def test_tunneling_setup_for_multiple_found_gateways(hass: HomeAssistant) -> None: - """Test tunneling if only one gateway is found.""" + """Test tunneling if multiple gateways are found.""" gateway = _gateway_descriptor("192.168.0.1", 3675) gateway2 = _gateway_descriptor("192.168.1.100", 3675) - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [gateway, gateway2] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock([gateway, gateway2]) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -601,8 +641,10 @@ async def test_manual_tunnel_step_with_found_gateway( hass: HomeAssistant, gateway ) -> None: """Test manual tunnel if gateway was found and tunneling is selected.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [gateway] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock([gateway]) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -634,8 +676,12 @@ async def test_manual_tunnel_step_with_found_gateway( async def test_form_with_automatic_connection_handling(hass: HomeAssistant) -> None: """Test we get the form.""" - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [_gateway_descriptor("192.168.0.1", 3675)] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock( + [_gateway_descriptor("192.168.0.1", 3675)] + ) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -672,8 +718,10 @@ async def _get_menu_step(hass: HomeAssistant) -> FlowResult: supports_tunnelling_tcp=True, requires_secure=True, ) - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [gateway] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock([gateway]) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -711,8 +759,10 @@ async def test_get_secure_menu_step_manual_tunnelling( supports_tunnelling_tcp=True, requires_secure=True, ) - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [gateway] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock([gateway]) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) @@ -899,12 +949,15 @@ async def test_options_flow_connection_type( ) -> None: """Test options flow changing interface.""" mock_config_entry.add_to_hass(hass) + hass.data[DOMAIN] = Mock() # GatewayScanner uses running XKNX() instance gateway = _gateway_descriptor("192.168.0.1", 3675) menu_step = await hass.config_entries.options.async_init(mock_config_entry.entry_id) - with patch("xknx.io.gateway_scanner.GatewayScanner.scan") as gateways: - gateways.return_value = [gateway] + with patch( + "homeassistant.components.knx.config_flow.GatewayScanner" + ) as gateway_scanner_mock: + gateway_scanner_mock.return_value = GatewayScannerMock([gateway]) result = await hass.config_entries.options.async_configure( menu_step["flow_id"], {"next_step_id": "connection_type"},