Add TypeVar default for FlowResult (#112345)

Co-authored-by: Marc Mueller <30130371+cdce8p@users.noreply.github.com>
This commit is contained in:
Erik Montnemery 2024-03-05 22:52:11 +01:00 committed by GitHub
parent 33fe6ad647
commit 3d3e9900c3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
14 changed files with 77 additions and 81 deletions

View File

@ -91,8 +91,6 @@ async def auth_manager_from_config(
class AuthManagerFlowManager(data_entry_flow.FlowManager):
"""Manage authentication flows."""
_flow_result = FlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows."""
super().__init__(hass)
@ -112,7 +110,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context)
async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: FlowResult
self, flow: data_entry_flow.FlowHandler, result: FlowResult
) -> FlowResult:
"""Return a user as result of login flow."""
flow = cast(LoginFlow, flow)

View File

@ -96,8 +96,6 @@ class MultiFactorAuthModule:
class SetupFlow(data_entry_flow.FlowHandler):
"""Handler for the setup flow."""
_flow_result = FlowResult
def __init__(
self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str
) -> None:

View File

@ -184,8 +184,6 @@ async def load_auth_provider_module(
class LoginFlow(data_entry_flow.FlowHandler):
"""Handler for the login flow."""
_flow_result = FlowResult
def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow."""
self._auth_provider = auth_provider

View File

@ -38,8 +38,6 @@ _LOGGER = logging.getLogger(__name__)
class MfaFlowManager(data_entry_flow.FlowManager):
"""Manage multi factor authentication flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow( # type: ignore[override]
self,
handler_key: str,
@ -56,7 +54,7 @@ class MfaFlowManager(data_entry_flow.FlowManager):
return await mfa_module.async_setup_flow(user_id)
async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete an mfs setup flow."""
_LOGGER.debug("flow_result: %s", result)

View File

@ -141,7 +141,9 @@ def _prepare_config_flow_result_json(
return data
class ConfigManagerFlowIndexView(FlowManagerIndexView):
class ConfigManagerFlowIndexView(
FlowManagerIndexView[config_entries.ConfigEntriesFlowManager]
):
"""View to create config flows."""
url = "/api/config/config_entries/flow"
@ -196,7 +198,9 @@ class ConfigManagerFlowIndexView(FlowManagerIndexView):
return _prepare_config_flow_result_json(result, super()._prepare_result_json)
class ConfigManagerFlowResourceView(FlowManagerResourceView):
class ConfigManagerFlowResourceView(
FlowManagerResourceView[config_entries.ConfigEntriesFlowManager]
):
"""View to interact with the flow manager."""
url = "/api/config/config_entries/flow/{flow_id}"
@ -238,7 +242,9 @@ class ConfigManagerAvailableFlowView(HomeAssistantView):
return self.json(await async_get_config_flows(hass, **kwargs))
class OptionManagerFlowIndexView(FlowManagerIndexView):
class OptionManagerFlowIndexView(
FlowManagerIndexView[config_entries.OptionsFlowManager]
):
"""View to create option flows."""
url = "/api/config/config_entries/options/flow"
@ -255,7 +261,9 @@ class OptionManagerFlowIndexView(FlowManagerIndexView):
return await super().post(request)
class OptionManagerFlowResourceView(FlowManagerResourceView):
class OptionManagerFlowResourceView(
FlowManagerResourceView[config_entries.OptionsFlowManager]
):
"""View to interact with the option flow manager."""
url = "/api/config/config_entries/options/flow/{flow_id}"

View File

@ -48,11 +48,9 @@ class ConfirmRepairFlow(RepairsFlow):
)
class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowResult]):
class RepairsFlowManager(data_entry_flow.FlowManager):
"""Manage repairs flows."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(
self,
handler_key: str,
@ -84,7 +82,7 @@ class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowRes
return flow
async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult
self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult
) -> data_entry_flow.FlowResult:
"""Complete a fix flow."""
if result.get("type") != data_entry_flow.FlowResultType.ABORT:

View File

@ -10,8 +10,6 @@ from homeassistant.core import HomeAssistant
class RepairsFlow(data_entry_flow.FlowHandler):
"""Handle a flow for fixing an issue."""
_flow_result = data_entry_flow.FlowResult
issue_id: str
data: dict[str, str | int | float | None] | None

View File

@ -34,7 +34,7 @@ from homeassistant.config_entries import (
)
from homeassistant.const import CONF_NAME, CONF_URL
from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import AbortFlow, BaseFlowManager
from homeassistant.data_entry_flow import AbortFlow, FlowManager
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property
@abstractmethod
def flow_manager(self) -> BaseFlowManager:
def flow_manager(self) -> FlowManager[ConfigFlowResult]:
"""Return the flow manager of the flow."""
async def async_step_install_addon(

View File

@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult
@ -1170,7 +1170,9 @@ class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]
self._discovery_debouncer.async_shutdown()
async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish a config flow and add an entry."""
flow = cast(ConfigFlow, flow)
@ -1290,7 +1292,9 @@ class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]
return flow
async def async_post_init(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> None:
"""After a flow is initialised trigger new flow notifications."""
source = flow.context["source"]
@ -1936,7 +1940,7 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]):
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]):
"""Base class for config and option flows."""
_flow_result = ConfigFlowResult
@ -2288,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
"""Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult
@ -2317,7 +2321,9 @@ class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
return handler.async_get_options_flow(entry)
async def async_finish_flow(
self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult
self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult],
result: ConfigFlowResult,
) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry.
@ -2337,7 +2343,9 @@ class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]):
result["result"] = True
return result
async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None:
async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult]
) -> None:
"""Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler)
await _load_integration(self.hass, entry.domain, {})

View File

@ -11,8 +11,9 @@ from enum import StrEnum
from functools import partial
import logging
from types import MappingProxyType
from typing import Any, Generic, Required, TypedDict, TypeVar
from typing import Any, Generic, Required, TypedDict
from typing_extensions import TypeVar
import voluptuous as vol
from .core import HomeAssistant, callback
@ -84,7 +85,7 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult")
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
@dataclass(slots=True)
@ -188,10 +189,10 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
class FlowManager(abc.ABC, Generic[_FlowResultT]):
"""Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT]
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
def __init__(
self,
@ -200,9 +201,9 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[str] = set()
self._progress: dict[str, BaseFlowHandler] = {}
self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {}
self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {}
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
@abc.abstractmethod
async def async_create_flow(
@ -211,7 +212,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> BaseFlowHandler[_FlowResultT]:
) -> FlowHandler[_FlowResultT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@ -219,12 +220,12 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: BaseFlowHandler, result: _FlowResultT
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> _FlowResultT:
"""Finish a data entry flow."""
async def async_post_init(
self, flow: BaseFlowHandler, result: _FlowResultT
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously."""
@ -298,7 +299,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None
) -> list[BaseFlowHandler[_FlowResultT]]:
) -> list[FlowHandler[_FlowResultT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@ -362,7 +363,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
data_schema := cur_step.get("data_schema")
) is not None and user_input is not None:
try:
user_input = data_schema(user_input)
user_input = data_schema(user_input) # type: ignore[operator]
except vol.Invalid as ex:
raised_errors = [ex]
if isinstance(ex, vol.MultipleInvalid):
@ -444,7 +445,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None:
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -453,9 +454,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback
def _async_remove_flow_from_index(
self, flow: BaseFlowHandler[_FlowResultT]
) -> None:
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@ -481,7 +480,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step(
self,
flow: BaseFlowHandler[_FlowResultT],
flow: FlowHandler[_FlowResultT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@ -558,7 +557,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return result
def _raise_if_step_does_not_exist(
self, flow: BaseFlowHandler, step_id: str
self, flow: FlowHandler[_FlowResultT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@ -569,7 +568,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: BaseFlowHandler) -> None:
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
self._preview.add(flow.handler)
@ -577,7 +576,7 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
@ -595,16 +594,10 @@ class BaseFlowManager(abc.ABC, Generic[_FlowResultT]):
return results
class FlowManager(BaseFlowManager[FlowResult]):
"""Manage all the flows that are in progress."""
_flow_result = FlowResult
class BaseFlowHandler(Generic[_FlowResultT]):
class FlowHandler(Generic[_FlowResultT]):
"""Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT]
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
# Set by flow manager
cur_step: _FlowResultT | None = None
@ -881,12 +874,6 @@ class BaseFlowHandler(Generic[_FlowResultT]):
self.__progress_task = progress_task
class FlowHandler(BaseFlowHandler[FlowResult]):
"""Handle a data entry flow."""
_flow_result = FlowResult
# These can be removed if no deprecated constant are in this module anymore
__getattr__ = partial(check_if_deprecated_constant, module_globals=globals())
__dir__ = partial(

View File

@ -2,9 +2,10 @@
from __future__ import annotations
from http import HTTPStatus
from typing import Any
from typing import Any, Generic
from aiohttp import web
from typing_extensions import TypeVar
import voluptuous as vol
import voluptuous_serialize
@ -14,11 +15,17 @@ from homeassistant.components.http.data_validator import RequestDataValidator
from . import config_validation as cv
_FlowManagerT = TypeVar(
"_FlowManagerT",
bound=data_entry_flow.FlowManager[Any],
default=data_entry_flow.FlowManager,
)
class _BaseFlowManagerView(HomeAssistantView):
class _BaseFlowManagerView(HomeAssistantView, Generic[_FlowManagerT]):
"""Foundation for flow manager views."""
def __init__(self, flow_mgr: data_entry_flow.BaseFlowManager) -> None:
def __init__(self, flow_mgr: _FlowManagerT) -> None:
"""Initialize the flow manager index view."""
self._flow_mgr = flow_mgr
@ -48,7 +55,7 @@ class _BaseFlowManagerView(HomeAssistantView):
return data
class FlowManagerIndexView(_BaseFlowManagerView):
class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
"""View to create config flows."""
@RequestDataValidator(
@ -96,7 +103,7 @@ class FlowManagerIndexView(_BaseFlowManagerView):
return {"show_advanced_options": data["show_advanced_options"]}
class FlowManagerResourceView(_BaseFlowManagerView):
class FlowManagerResourceView(_BaseFlowManagerView[_FlowManagerT]):
"""View to interact with the flow manager."""
async def get(self, request: web.Request, /, flow_id: str) -> web.Response:

View File

@ -63,7 +63,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup(
"""Test existing flows prevent an identical ones from being after startup."""
hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED)
with patch(
"homeassistant.data_entry_flow.BaseFlowManager.async_has_matching_flow",
"homeassistant.data_entry_flow.FlowManager.async_has_matching_flow",
return_value=True,
):
discovery_flow.async_create_flow(

View File

@ -45,7 +45,7 @@ def manager_fixture():
handlers = Registry()
entries = []
class FlowManager(data_entry_flow.BaseFlowManager):
class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager."""
async def async_create_flow(self, handler_key, *, context, data):
@ -105,7 +105,7 @@ async def test_name(hass: HomeAssistant, entity_registry: er.EntityRegistry) ->
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_config_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
) -> None:
"""Test handling of advanced options in config flow."""
manager.hass = hass
@ -200,7 +200,7 @@ async def test_config_flow_advanced_option(
@pytest.mark.parametrize("marker", (vol.Required, vol.Optional))
async def test_options_flow_advanced_option(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker
hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker
) -> None:
"""Test handling of advanced options in options flow."""
manager.hass = hass
@ -475,7 +475,7 @@ async def test_next_step_function(hass: HomeAssistant) -> None:
async def test_suggested_values(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None:
"""Test suggested_values handling in SchemaFlowFormStep."""
manager.hass = hass
@ -667,7 +667,7 @@ async def test_options_flow_state(hass: HomeAssistant) -> None:
async def test_options_flow_omit_optional_keys(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None:
"""Test handling of advanced options in options flow."""
manager.hass = hass

View File

@ -24,11 +24,9 @@ def manager():
handlers = Registry()
entries = []
class FlowManager(data_entry_flow.BaseFlowManager):
class FlowManager(data_entry_flow.FlowManager):
"""Test flow manager."""
_flow_result = data_entry_flow.FlowResult
async def async_create_flow(self, handler_key, *, context, data):
"""Test create flow."""
handler = handlers.get(handler_key)
@ -81,7 +79,7 @@ async def test_configure_reuses_handler_instance(manager) -> None:
assert len(manager.mock_created_entries) == 0
async def test_configure_two_steps(manager: data_entry_flow.BaseFlowManager) -> None:
async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None:
"""Test that we reuse instances."""
@manager.mock_reg_handler("test")
@ -258,7 +256,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None:
step_id="init", data_schema=vol.Schema({"count": int})
)
class FlowManager(data_entry_flow.BaseFlowManager):
class FlowManager(data_entry_flow.FlowManager):
async def async_create_flow(self, handler_name, *, context, data):
"""Create a test flow."""
return TestFlow()
@ -775,7 +773,7 @@ async def test_async_get_unknown_flow(manager) -> None:
async def test_async_has_matching_flow(
hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager
hass: HomeAssistant, manager: data_entry_flow.FlowManager
) -> None:
"""Test we can check for matching flows."""
manager.hass = hass
@ -951,7 +949,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None:
async def test_find_flows_by_init_data_type(
manager: data_entry_flow.BaseFlowManager,
manager: data_entry_flow.FlowManager,
) -> None:
"""Test we can find flows by init data type."""