mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Fix device discovery of OAuth2 config flows (#48326)
This commit is contained in:
parent
4f4a6fd6a5
commit
f0e5e616a7
@ -1049,11 +1049,13 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
|||||||
}
|
}
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def _async_in_progress(self) -> list[dict]:
|
def _async_in_progress(self, include_uninitialized: bool = False) -> list[dict]:
|
||||||
"""Return other in progress flows for current domain."""
|
"""Return other in progress flows for current domain."""
|
||||||
return [
|
return [
|
||||||
flw
|
flw
|
||||||
for flw in self.hass.config_entries.flow.async_progress()
|
for flw in self.hass.config_entries.flow.async_progress(
|
||||||
|
include_uninitialized=include_uninitialized
|
||||||
|
)
|
||||||
if flw["handler"] == self.handler and flw["flow_id"] != self.flow_id
|
if flw["handler"] == self.handler and flw["flow_id"] != self.flow_id
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -1093,7 +1095,7 @@ class ConfigFlow(data_entry_flow.FlowHandler):
|
|||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
# Abort if any other flow for this handler is already in progress
|
# Abort if any other flow for this handler is already in progress
|
||||||
if self._async_in_progress():
|
if self._async_in_progress(include_uninitialized=True):
|
||||||
raise data_entry_flow.AbortFlow("already_in_progress")
|
raise data_entry_flow.AbortFlow("already_in_progress")
|
||||||
|
|
||||||
async def async_step_discovery(
|
async def async_step_discovery(
|
||||||
|
@ -94,17 +94,17 @@ class FlowManager(abc.ABC):
|
|||||||
"""Entry has finished executing its first step asynchronously."""
|
"""Entry has finished executing its first step asynchronously."""
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_progress(self) -> list[dict]:
|
def async_progress(self, include_uninitialized: bool = False) -> list[dict]:
|
||||||
"""Return the flows in progress."""
|
"""Return the flows in progress."""
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"flow_id": flow.flow_id,
|
"flow_id": flow.flow_id,
|
||||||
"handler": flow.handler,
|
"handler": flow.handler,
|
||||||
"context": flow.context,
|
"context": flow.context,
|
||||||
"step_id": flow.cur_step["step_id"],
|
"step_id": flow.cur_step["step_id"] if flow.cur_step else None,
|
||||||
}
|
}
|
||||||
for flow in self._progress.values()
|
for flow in self._progress.values()
|
||||||
if flow.cur_step is not None
|
if include_uninitialized or flow.cur_step is not None
|
||||||
]
|
]
|
||||||
|
|
||||||
async def async_init(
|
async def async_init(
|
||||||
|
@ -245,8 +245,10 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
|||||||
if not implementations:
|
if not implementations:
|
||||||
return self.async_abort(reason="missing_configuration")
|
return self.async_abort(reason="missing_configuration")
|
||||||
|
|
||||||
if len(implementations) == 1:
|
req = http.current_request.get()
|
||||||
# Pick first implementation as we have only one.
|
if len(implementations) == 1 and req is not None:
|
||||||
|
# Pick first implementation if we have only one, but only
|
||||||
|
# if this is triggered by a user interaction (request).
|
||||||
self.flow_impl = list(implementations.values())[0]
|
self.flow_impl = list(implementations.values())[0]
|
||||||
return await self.async_step_auth()
|
return await self.async_step_auth()
|
||||||
|
|
||||||
@ -313,23 +315,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
|||||||
"""
|
"""
|
||||||
return self.async_create_entry(title=self.flow_impl.name, data=data)
|
return self.async_create_entry(title=self.flow_impl.name, data=data)
|
||||||
|
|
||||||
async def async_step_discovery(
|
|
||||||
self, discovery_info: dict[str, Any]
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Handle a flow initialized by discovery."""
|
|
||||||
await self.async_set_unique_id(self.DOMAIN)
|
|
||||||
|
|
||||||
if self.hass.config_entries.async_entries(self.DOMAIN):
|
|
||||||
return self.async_abort(reason="already_configured")
|
|
||||||
|
|
||||||
return await self.async_step_pick_implementation()
|
|
||||||
|
|
||||||
async_step_user = async_step_pick_implementation
|
async_step_user = async_step_pick_implementation
|
||||||
async_step_mqtt = async_step_discovery
|
|
||||||
async_step_ssdp = async_step_discovery
|
|
||||||
async_step_zeroconf = async_step_discovery
|
|
||||||
async_step_homekit = async_step_discovery
|
|
||||||
async_step_dhcp = async_step_discovery
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def async_register_implementation(
|
def async_register_implementation(
|
||||||
|
@ -108,7 +108,7 @@ async def test_get_services_error(hass):
|
|||||||
assert account_link.DATA_SERVICES not in hass.data
|
assert account_link.DATA_SERVICES not in hass.data
|
||||||
|
|
||||||
|
|
||||||
async def test_implementation(hass, flow_handler):
|
async def test_implementation(hass, flow_handler, current_request_with_host):
|
||||||
"""Test Cloud OAuth2 implementation."""
|
"""Test Cloud OAuth2 implementation."""
|
||||||
hass.data["cloud"] = None
|
hass.data["cloud"] = None
|
||||||
|
|
||||||
|
@ -123,7 +123,9 @@ async def test_full_flow(
|
|||||||
assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
|
assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
|
||||||
|
|
||||||
|
|
||||||
async def test_abort_if_authorization_timeout(hass, mock_impl):
|
async def test_abort_if_authorization_timeout(
|
||||||
|
hass, mock_impl, current_request_with_host
|
||||||
|
):
|
||||||
"""Check Somfy authorization timeout."""
|
"""Check Somfy authorization timeout."""
|
||||||
flow = config_flow.SomfyFlowHandler()
|
flow = config_flow.SomfyFlowHandler()
|
||||||
flow.hass = hass
|
flow.hass = hass
|
||||||
|
@ -15,7 +15,7 @@ from .common import ComponentFactory, new_profile_config
|
|||||||
|
|
||||||
|
|
||||||
async def test_binary_sensor(
|
async def test_binary_sensor(
|
||||||
hass: HomeAssistant, component_factory: ComponentFactory
|
hass: HomeAssistant, component_factory: ComponentFactory, current_request_with_host
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test binary sensor."""
|
"""Test binary sensor."""
|
||||||
in_bed_attribute = WITHINGS_MEASUREMENTS_MAP[Measurement.IN_BED]
|
in_bed_attribute = WITHINGS_MEASUREMENTS_MAP[Measurement.IN_BED]
|
||||||
|
@ -74,6 +74,7 @@ async def test_webhook_post(
|
|||||||
arg_user_id: Any,
|
arg_user_id: Any,
|
||||||
arg_appli: Any,
|
arg_appli: Any,
|
||||||
expected_code: int,
|
expected_code: int,
|
||||||
|
current_request_with_host,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test webhook callback."""
|
"""Test webhook callback."""
|
||||||
person0 = new_profile_config("person0", user_id)
|
person0 = new_profile_config("person0", user_id)
|
||||||
@ -107,6 +108,7 @@ async def test_webhook_head(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
component_factory: ComponentFactory,
|
component_factory: ComponentFactory,
|
||||||
aiohttp_client,
|
aiohttp_client,
|
||||||
|
current_request_with_host,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test head method on webhook view."""
|
"""Test head method on webhook view."""
|
||||||
person0 = new_profile_config("person0", 0)
|
person0 = new_profile_config("person0", 0)
|
||||||
@ -124,6 +126,7 @@ async def test_webhook_put(
|
|||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
component_factory: ComponentFactory,
|
component_factory: ComponentFactory,
|
||||||
aiohttp_client,
|
aiohttp_client,
|
||||||
|
current_request_with_host,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test webhook callback."""
|
"""Test webhook callback."""
|
||||||
person0 = new_profile_config("person0", 0)
|
person0 = new_profile_config("person0", 0)
|
||||||
|
@ -34,7 +34,7 @@ async def test_config_non_unique_profile(hass: HomeAssistant) -> None:
|
|||||||
|
|
||||||
|
|
||||||
async def test_config_reauth_profile(
|
async def test_config_reauth_profile(
|
||||||
hass: HomeAssistant, aiohttp_client, aioclient_mock
|
hass: HomeAssistant, aiohttp_client, aioclient_mock, current_request_with_host
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test reauth an existing profile re-creates the config entry."""
|
"""Test reauth an existing profile re-creates the config entry."""
|
||||||
hass_config = {
|
hass_config = {
|
||||||
|
@ -125,7 +125,10 @@ async def test_async_setup_no_config(hass: HomeAssistant) -> None:
|
|||||||
],
|
],
|
||||||
)
|
)
|
||||||
async def test_auth_failure(
|
async def test_auth_failure(
|
||||||
hass: HomeAssistant, component_factory: ComponentFactory, exception: Exception
|
hass: HomeAssistant,
|
||||||
|
component_factory: ComponentFactory,
|
||||||
|
exception: Exception,
|
||||||
|
current_request_with_host,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test auth failure."""
|
"""Test auth failure."""
|
||||||
person0 = new_profile_config(
|
person0 = new_profile_config(
|
||||||
|
@ -302,7 +302,7 @@ def async_assert_state_equals(
|
|||||||
|
|
||||||
|
|
||||||
async def test_sensor_default_enabled_entities(
|
async def test_sensor_default_enabled_entities(
|
||||||
hass: HomeAssistant, component_factory: ComponentFactory
|
hass: HomeAssistant, component_factory: ComponentFactory, current_request_with_host
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test entities enabled by default."""
|
"""Test entities enabled by default."""
|
||||||
entity_registry: EntityRegistry = er.async_get(hass)
|
entity_registry: EntityRegistry = er.async_get(hass)
|
||||||
@ -343,7 +343,7 @@ async def test_sensor_default_enabled_entities(
|
|||||||
|
|
||||||
|
|
||||||
async def test_all_entities(
|
async def test_all_entities(
|
||||||
hass: HomeAssistant, component_factory: ComponentFactory
|
hass: HomeAssistant, component_factory: ComponentFactory, current_request_with_host
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test all entities."""
|
"""Test all entities."""
|
||||||
entity_registry: EntityRegistry = er.async_get(hass)
|
entity_registry: EntityRegistry = er.async_get(hass)
|
||||||
|
@ -113,7 +113,9 @@ async def test_abort_if_no_implementation(hass, flow_handler):
|
|||||||
assert result["reason"] == "missing_configuration"
|
assert result["reason"] == "missing_configuration"
|
||||||
|
|
||||||
|
|
||||||
async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
|
async def test_abort_if_authorization_timeout(
|
||||||
|
hass, flow_handler, local_impl, current_request_with_host
|
||||||
|
):
|
||||||
"""Check timeout generating authorization url."""
|
"""Check timeout generating authorization url."""
|
||||||
flow_handler.async_register_implementation(hass, local_impl)
|
flow_handler.async_register_implementation(hass, local_impl)
|
||||||
|
|
||||||
@ -129,7 +131,9 @@ async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
|
|||||||
assert result["reason"] == "authorize_url_timeout"
|
assert result["reason"] == "authorize_url_timeout"
|
||||||
|
|
||||||
|
|
||||||
async def test_abort_if_no_url_available(hass, flow_handler, local_impl):
|
async def test_abort_if_no_url_available(
|
||||||
|
hass, flow_handler, local_impl, current_request_with_host
|
||||||
|
):
|
||||||
"""Check no_url_available generating authorization url."""
|
"""Check no_url_available generating authorization url."""
|
||||||
flow_handler.async_register_implementation(hass, local_impl)
|
flow_handler.async_register_implementation(hass, local_impl)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user