Fix DiscoveryFlowHandler when discovery_function returns bool (#133563)

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Konrad Vité 2025-01-16 23:31:16 +01:00 committed by GitHub
parent e5164496cf
commit e6c696933f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 63 additions and 10 deletions

View File

@ -67,9 +67,11 @@ class DiscoveryFlowHandler[_R: Awaitable[bool] | bool](config_entries.ConfigFlow
in_progress = self._async_in_progress() in_progress = self._async_in_progress()
if not (has_devices := bool(in_progress)): if not (has_devices := bool(in_progress)):
has_devices = await cast( discovery_result = self._discovery_function(self.hass)
"asyncio.Future[bool]", self._discovery_function(self.hass) if isinstance(discovery_result, bool):
) has_devices = discovery_result
else:
has_devices = await cast("asyncio.Future[bool]", discovery_result)
if not has_devices: if not has_devices:
return self.async_abort(reason="no_devices_found") return self.async_abort(reason="no_devices_found")

View File

@ -1,6 +1,8 @@
"""Tests for the Config Entry Flow helper.""" """Tests for the Config Entry Flow helper."""
from collections.abc import Generator import asyncio
from collections.abc import Callable, Generator
from contextlib import contextmanager
from unittest.mock import Mock, PropertyMock, patch from unittest.mock import Mock, PropertyMock, patch
import pytest import pytest
@ -13,22 +15,44 @@ from homeassistant.helpers import config_entry_flow
from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform from tests.common import MockConfigEntry, MockModule, mock_integration, mock_platform
@contextmanager
def _make_discovery_flow_conf(
has_discovered_devices: Callable[[], asyncio.Future[bool] | bool],
) -> Generator[None]:
with patch.dict(config_entries.HANDLERS):
config_entry_flow.register_discovery_flow(
"test", "Test", has_discovered_devices
)
yield
@pytest.fixture @pytest.fixture
def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]: def async_discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]:
"""Register a handler.""" """Register a handler with an async discovery function."""
handler_conf = {"discovered": False} handler_conf = {"discovered": False}
async def has_discovered_devices(hass: HomeAssistant) -> bool: async def has_discovered_devices(hass: HomeAssistant) -> bool:
"""Mock if we have discovered devices.""" """Mock if we have discovered devices."""
return handler_conf["discovered"] return handler_conf["discovered"]
with patch.dict(config_entries.HANDLERS): with _make_discovery_flow_conf(has_discovered_devices):
config_entry_flow.register_discovery_flow(
"test", "Test", has_discovered_devices
)
yield handler_conf yield handler_conf
@pytest.fixture
def discovery_flow_conf(hass: HomeAssistant) -> Generator[dict[str, bool]]:
"""Register a handler with a async friendly callback function."""
handler_conf = {"discovered": False}
def has_discovered_devices(hass: HomeAssistant) -> bool:
"""Mock if we have discovered devices."""
return handler_conf["discovered"]
with _make_discovery_flow_conf(has_discovered_devices):
yield handler_conf
handler_conf = {"discovered": False}
@pytest.fixture @pytest.fixture
def webhook_flow_conf(hass: HomeAssistant) -> Generator[None]: def webhook_flow_conf(hass: HomeAssistant) -> Generator[None]:
"""Register a handler.""" """Register a handler."""
@ -95,6 +119,33 @@ async def test_user_has_confirmation(
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
async def test_user_has_confirmation_async_discovery_flow(
hass: HomeAssistant, async_discovery_flow_conf: dict[str, bool]
) -> None:
"""Test user requires confirmation to setup with an async has_discovered_devices."""
async_discovery_flow_conf["discovered"] = True
mock_platform(hass, "test.config_flow", None)
result = await hass.config_entries.flow.async_init(
"test", context={"source": config_entries.SOURCE_USER}, data={}
)
assert result["type"] == data_entry_flow.FlowResultType.FORM
assert result["step_id"] == "confirm"
progress = hass.config_entries.flow.async_progress()
assert len(progress) == 1
assert progress[0]["flow_id"] == result["flow_id"]
assert progress[0]["context"] == {
"confirm_only": True,
"source": config_entries.SOURCE_USER,
"unique_id": "test",
}
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
assert result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY
@pytest.mark.parametrize( @pytest.mark.parametrize(
"source", "source",
[ [