From a0e558c4576cd8228ea8af9d3533438721b22d4f Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Thu, 29 Feb 2024 16:52:39 +0100 Subject: [PATCH] Add generic classes BaseFlowHandler and BaseFlowManager (#111814) * Add generic classes BaseFlowHandler and BaseFlowManager * Migrate zwave_js * Update tests * Update tests * Address review comments --- homeassistant/auth/__init__.py | 4 +- homeassistant/auth/mfa_modules/__init__.py | 2 + homeassistant/auth/providers/__init__.py | 2 + .../components/auth/mfa_setup_flow.py | 4 +- .../components/repairs/issue_handler.py | 6 +- homeassistant/components/repairs/models.py | 4 +- .../components/zwave_js/config_flow.py | 100 +++++---- homeassistant/config_entries.py | 87 ++++---- homeassistant/data_entry_flow.py | 199 ++++++++++-------- homeassistant/helpers/config_entry_flow.py | 31 +-- .../helpers/config_entry_oauth2_flow.py | 17 +- homeassistant/helpers/data_entry_flow.py | 2 +- homeassistant/helpers/discovery_flow.py | 4 +- .../helpers/schema_config_entry_flow.py | 46 ++-- pylint/plugins/hass_enforce_type_hints.py | 49 +++-- .../config_flow/integration/config_flow.py | 7 +- tests/components/cloud/test_repairs.py | 2 - .../components/config/test_config_entries.py | 2 - tests/components/hassio/test_repairs.py | 12 -- tests/components/kitchen_sink/test_init.py | 2 - .../components/repairs/test_websocket_api.py | 2 - tests/helpers/test_discovery_flow.py | 2 +- .../helpers/test_schema_config_entry_flow.py | 10 +- tests/pylint/test_enforce_type_hints.py | 4 +- tests/test_data_entry_flow.py | 14 +- 25 files changed, 341 insertions(+), 273 deletions(-) diff --git a/homeassistant/auth/__init__.py b/homeassistant/auth/__init__.py index f99e90dbc05..a68f8bc95eb 100644 --- a/homeassistant/auth/__init__.py +++ b/homeassistant/auth/__init__.py @@ -91,6 +91,8 @@ async def auth_manager_from_config( class AuthManagerFlowManager(data_entry_flow.FlowManager): """Manage authentication flows.""" + _flow_result = FlowResult + def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None: """Init auth manager flows.""" super().__init__(hass) @@ -110,7 +112,7 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager): return await auth_provider.async_login_flow(context) async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: FlowResult + self, flow: data_entry_flow.BaseFlowHandler, result: FlowResult ) -> FlowResult: """Return a user as result of login flow.""" flow = cast(LoginFlow, flow) diff --git a/homeassistant/auth/mfa_modules/__init__.py b/homeassistant/auth/mfa_modules/__init__.py index aa28710d8c6..3c8c0e3a096 100644 --- a/homeassistant/auth/mfa_modules/__init__.py +++ b/homeassistant/auth/mfa_modules/__init__.py @@ -96,6 +96,8 @@ class MultiFactorAuthModule: class SetupFlow(data_entry_flow.FlowHandler): """Handler for the setup flow.""" + _flow_result = FlowResult + def __init__( self, auth_module: MultiFactorAuthModule, setup_schema: vol.Schema, user_id: str ) -> None: diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 7d74dd2dc26..577955d7c75 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -184,6 +184,8 @@ async def load_auth_provider_module( class LoginFlow(data_entry_flow.FlowHandler): """Handler for the login flow.""" + _flow_result = FlowResult + def __init__(self, auth_provider: AuthProvider) -> None: """Initialize the login flow.""" self._auth_provider = auth_provider diff --git a/homeassistant/components/auth/mfa_setup_flow.py b/homeassistant/components/auth/mfa_setup_flow.py index a7999af666a..58c45c56b85 100644 --- a/homeassistant/components/auth/mfa_setup_flow.py +++ b/homeassistant/components/auth/mfa_setup_flow.py @@ -38,6 +38,8 @@ _LOGGER = logging.getLogger(__name__) class MfaFlowManager(data_entry_flow.FlowManager): """Manage multi factor authentication flows.""" + _flow_result = data_entry_flow.FlowResult + async def async_create_flow( # type: ignore[override] self, handler_key: str, @@ -54,7 +56,7 @@ class MfaFlowManager(data_entry_flow.FlowManager): return await mfa_module.async_setup_flow(user_id) async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult + self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult ) -> data_entry_flow.FlowResult: """Complete an mfs setup flow.""" _LOGGER.debug("flow_result: %s", result) diff --git a/homeassistant/components/repairs/issue_handler.py b/homeassistant/components/repairs/issue_handler.py index f2ce3bac84e..388edc56254 100644 --- a/homeassistant/components/repairs/issue_handler.py +++ b/homeassistant/components/repairs/issue_handler.py @@ -48,9 +48,11 @@ class ConfirmRepairFlow(RepairsFlow): ) -class RepairsFlowManager(data_entry_flow.FlowManager): +class RepairsFlowManager(data_entry_flow.BaseFlowManager[data_entry_flow.FlowResult]): """Manage repairs flows.""" + _flow_result = data_entry_flow.FlowResult + async def async_create_flow( self, handler_key: str, @@ -82,7 +84,7 @@ class RepairsFlowManager(data_entry_flow.FlowManager): return flow async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult + self, flow: data_entry_flow.BaseFlowHandler, result: data_entry_flow.FlowResult ) -> data_entry_flow.FlowResult: """Complete a fix flow.""" if result.get("type") != data_entry_flow.FlowResultType.ABORT: diff --git a/homeassistant/components/repairs/models.py b/homeassistant/components/repairs/models.py index 6ae175b29e9..63b3199141b 100644 --- a/homeassistant/components/repairs/models.py +++ b/homeassistant/components/repairs/models.py @@ -7,9 +7,11 @@ from homeassistant import data_entry_flow from homeassistant.core import HomeAssistant -class RepairsFlow(data_entry_flow.FlowHandler): +class RepairsFlow(data_entry_flow.BaseFlowHandler[data_entry_flow.FlowResult]): """Handle a flow for fixing an issue.""" + _flow_result = data_entry_flow.FlowResult + issue_id: str data: dict[str, str | int | float | None] | None diff --git a/homeassistant/components/zwave_js/config_flow.py b/homeassistant/components/zwave_js/config_flow.py index c3fd2836048..9eccb032120 100644 --- a/homeassistant/components/zwave_js/config_flow.py +++ b/homeassistant/components/zwave_js/config_flow.py @@ -11,7 +11,6 @@ from serial.tools import list_ports import voluptuous as vol from zwave_js_server.version import VersionInfo, get_server_version -from homeassistant import config_entries, exceptions from homeassistant.components import usb from homeassistant.components.hassio import ( AddonError, @@ -22,14 +21,21 @@ from homeassistant.components.hassio import ( is_hassio, ) from homeassistant.components.zeroconf import ZeroconfServiceInfo +from homeassistant.config_entries import ( + SOURCE_USB, + ConfigEntriesFlowManager, + ConfigEntry, + ConfigEntryBaseFlow, + ConfigEntryState, + ConfigFlow, + ConfigFlowResult, + OptionsFlow, + OptionsFlowManager, +) from homeassistant.const import CONF_NAME, CONF_URL from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import ( - AbortFlow, - FlowHandler, - FlowManager, - FlowResult, -) +from homeassistant.data_entry_flow import AbortFlow, BaseFlowManager +from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.aiohttp_client import async_get_clientsession from . import disconnect_client @@ -156,7 +162,7 @@ async def async_get_usb_ports(hass: HomeAssistant) -> dict[str, str]: return await hass.async_add_executor_job(get_usb_ports) -class BaseZwaveJSFlow(FlowHandler, ABC): +class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC): """Represent the base config flow for Z-Wave JS.""" def __init__(self) -> None: @@ -176,12 +182,12 @@ class BaseZwaveJSFlow(FlowHandler, ABC): @property @abstractmethod - def flow_manager(self) -> FlowManager: + def flow_manager(self) -> BaseFlowManager: """Return the flow manager of the flow.""" async def async_step_install_addon( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Install Z-Wave JS add-on.""" if not self.install_task: self.install_task = self.hass.async_create_task(self._async_install_addon()) @@ -207,13 +213,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC): async def async_step_install_failed( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Add-on installation failed.""" return self.async_abort(reason="addon_install_failed") async def async_step_start_addon( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Start Z-Wave JS add-on.""" if not self.start_task: self.start_task = self.hass.async_create_task(self._async_start_addon()) @@ -237,7 +243,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC): async def async_step_start_failed( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Add-on start failed.""" return self.async_abort(reason="addon_start_failed") @@ -275,13 +281,13 @@ class BaseZwaveJSFlow(FlowHandler, ABC): @abstractmethod async def async_step_configure_addon( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Ask for config for Z-Wave JS add-on.""" @abstractmethod async def async_step_finish_addon_setup( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Prepare info needed to complete the config entry. Get add-on discovery info and server version info. @@ -325,7 +331,7 @@ class BaseZwaveJSFlow(FlowHandler, ABC): return discovery_info_config -class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): +class ZWaveJSConfigFlow(BaseZwaveJSFlow, ConfigFlow, domain=DOMAIN): """Handle a config flow for Z-Wave JS.""" VERSION = 1 @@ -338,19 +344,19 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): self._usb_discovery = False @property - def flow_manager(self) -> config_entries.ConfigEntriesFlowManager: + def flow_manager(self) -> ConfigEntriesFlowManager: """Return the correct flow manager.""" return self.hass.config_entries.flow @staticmethod @callback def async_get_options_flow( - config_entry: config_entries.ConfigEntry, + config_entry: ConfigEntry, ) -> OptionsFlowHandler: """Return the options flow.""" return OptionsFlowHandler(config_entry) - async def async_step_import(self, data: dict[str, Any]) -> FlowResult: + async def async_step_import(self, data: dict[str, Any]) -> ConfigFlowResult: """Handle imported data. This step will be used when importing data @@ -364,7 +370,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle the initial step.""" if is_hassio(self.hass): return await self.async_step_on_supervisor() @@ -373,7 +379,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_zeroconf( self, discovery_info: ZeroconfServiceInfo - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle zeroconf discovery.""" home_id = str(discovery_info.properties["homeId"]) await self.async_set_unique_id(home_id) @@ -384,7 +390,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_zeroconf_confirm( self, user_input: dict | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Confirm the setup.""" if user_input is not None: return await self.async_step_manual({CONF_URL: self.ws_address}) @@ -398,7 +404,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): }, ) - async def async_step_usb(self, discovery_info: usb.UsbServiceInfo) -> FlowResult: + async def async_step_usb( + self, discovery_info: usb.UsbServiceInfo + ) -> ConfigFlowResult: """Handle USB Discovery.""" if not is_hassio(self.hass): return self.async_abort(reason="discovery_requires_supervisor") @@ -441,7 +449,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_usb_confirm( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle USB Discovery confirmation.""" if user_input is None: return self.async_show_form( @@ -455,7 +463,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_manual( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a manual configuration.""" if user_input is None: return self.async_show_form( @@ -491,7 +499,9 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): step_id="manual", data_schema=get_manual_schema(user_input), errors=errors ) - async def async_step_hassio(self, discovery_info: HassioServiceInfo) -> FlowResult: + async def async_step_hassio( + self, discovery_info: HassioServiceInfo + ) -> ConfigFlowResult: """Receive configuration from add-on discovery info. This flow is triggered by the Z-Wave JS add-on. @@ -517,7 +527,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_hassio_confirm( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Confirm the add-on discovery.""" if user_input is not None: return await self.async_step_on_supervisor( @@ -528,7 +538,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_on_supervisor( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle logic when on Supervisor host.""" if user_input is None: return self.async_show_form( @@ -563,7 +573,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_configure_addon( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Ask for config for Z-Wave JS add-on.""" addon_info = await self._async_get_addon_info() addon_config = addon_info.options @@ -628,7 +638,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): async def async_step_finish_addon_setup( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Prepare info needed to complete the config entry. Get add-on discovery info and server version info. @@ -638,7 +648,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): discovery_info = await self._async_get_addon_discovery_info() self.ws_address = f"ws://{discovery_info['host']}:{discovery_info['port']}" - if not self.unique_id or self.context["source"] == config_entries.SOURCE_USB: + if not self.unique_id or self.context["source"] == SOURCE_USB: if not self.version_info: try: self.version_info = await async_get_version_info( @@ -664,7 +674,7 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): return self._async_create_entry_from_vars() @callback - def _async_create_entry_from_vars(self) -> FlowResult: + def _async_create_entry_from_vars(self) -> ConfigFlowResult: """Return a config entry for the flow.""" # Abort any other flows that may be in progress for progress in self._async_in_progress(): @@ -685,10 +695,10 @@ class ConfigFlow(BaseZwaveJSFlow, config_entries.ConfigFlow, domain=DOMAIN): ) -class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): +class OptionsFlowHandler(BaseZwaveJSFlow, OptionsFlow): """Handle an options flow for Z-Wave JS.""" - def __init__(self, config_entry: config_entries.ConfigEntry) -> None: + def __init__(self, config_entry: ConfigEntry) -> None: """Set up the options flow.""" super().__init__() self.config_entry = config_entry @@ -696,7 +706,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): self.revert_reason: str | None = None @property - def flow_manager(self) -> config_entries.OptionsFlowManager: + def flow_manager(self) -> OptionsFlowManager: """Return the correct flow manager.""" return self.hass.config_entries.options @@ -707,7 +717,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): async def async_step_init( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Manage the options.""" if is_hassio(self.hass): return await self.async_step_on_supervisor() @@ -716,7 +726,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): async def async_step_manual( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a manual configuration.""" if user_input is None: return self.async_show_form( @@ -759,7 +769,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): async def async_step_on_supervisor( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle logic when on Supervisor host.""" if user_input is None: return self.async_show_form( @@ -780,7 +790,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): async def async_step_configure_addon( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Ask for config for Z-Wave JS add-on.""" addon_info = await self._async_get_addon_info() addon_config = addon_info.options @@ -819,7 +829,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): if ( self.config_entry.data.get(CONF_USE_ADDON) - and self.config_entry.state == config_entries.ConfigEntryState.LOADED + and self.config_entry.state == ConfigEntryState.LOADED ): # Disconnect integration before restarting add-on. await disconnect_client(self.hass, self.config_entry) @@ -868,13 +878,13 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): async def async_step_start_failed( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Add-on start failed.""" return await self.async_revert_addon_config(reason="addon_start_failed") async def async_step_finish_addon_setup( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Prepare info needed to complete the config entry update. Get add-on discovery info and server version info. @@ -918,7 +928,7 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): self.hass.config_entries.async_schedule_reload(self.config_entry.entry_id) return self.async_create_entry(title=TITLE, data={}) - async def async_revert_addon_config(self, reason: str) -> FlowResult: + async def async_revert_addon_config(self, reason: str) -> ConfigFlowResult: """Abort the options flow. If the add-on options have been changed, revert those and restart add-on. @@ -944,11 +954,11 @@ class OptionsFlowHandler(BaseZwaveJSFlow, config_entries.OptionsFlow): return await self.async_step_configure_addon(addon_config_input) -class CannotConnect(exceptions.HomeAssistantError): +class CannotConnect(HomeAssistantError): """Indicate connection error.""" -class InvalidInput(exceptions.HomeAssistantError): +class InvalidInput(HomeAssistantError): """Error to indicate input data is invalid.""" def __init__(self, error: str) -> None: diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 1ca40886da2..2200831e576 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -242,6 +242,9 @@ UPDATE_ENTRY_CONFIG_ENTRY_ATTRS = { } +ConfigFlowResult = FlowResult + + class ConfigEntry: """Hold a configuration entry.""" @@ -903,7 +906,7 @@ class ConfigEntry: @callback def async_get_active_flows( self, hass: HomeAssistant, sources: set[str] - ) -> Generator[FlowResult, None, None]: + ) -> Generator[ConfigFlowResult, None, None]: """Get any active flows of certain sources for this entry.""" return ( flow @@ -970,9 +973,11 @@ class FlowCancelledError(Exception): """Error to indicate that a flow has been cancelled.""" -class ConfigEntriesFlowManager(data_entry_flow.FlowManager): +class ConfigEntriesFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]): """Manage all the config entry flows that are in progress.""" + _flow_result = ConfigFlowResult + def __init__( self, hass: HomeAssistant, @@ -1010,7 +1015,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): async def async_init( self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Start a configuration flow.""" if not context or "source" not in context: raise KeyError("Context not set or doesn't have a source set") @@ -1024,7 +1029,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): and await _support_single_config_entry_only(self.hass, handler) and self.config_entries.async_entries(handler, include_ignore=False) ): - return FlowResult( + return ConfigFlowResult( type=data_entry_flow.FlowResultType.ABORT, flow_id=flow_id, handler=handler, @@ -1065,7 +1070,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): handler: str, context: dict, data: Any, - ) -> tuple[data_entry_flow.FlowHandler, FlowResult]: + ) -> tuple[ConfigFlow, ConfigFlowResult]: """Run the init in a task to allow it to be canceled at shutdown.""" flow = await self.async_create_flow(handler, context=context, data=data) if not flow: @@ -1093,8 +1098,8 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): self._discovery_debouncer.async_shutdown() async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult - ) -> data_entry_flow.FlowResult: + self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult + ) -> ConfigFlowResult: """Finish a config flow and add an entry.""" flow = cast(ConfigFlow, flow) @@ -1128,7 +1133,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): and flow.context["source"] != SOURCE_IGNORE and self.config_entries.async_entries(flow.handler, include_ignore=False) ): - return FlowResult( + return ConfigFlowResult( type=data_entry_flow.FlowResultType.ABORT, flow_id=flow.flow_id, handler=flow.handler, @@ -1213,7 +1218,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager): return flow async def async_post_init( - self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult + self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult ) -> None: """After a flow is initialised trigger new flow notifications.""" source = flow.context["source"] @@ -1852,7 +1857,13 @@ def _async_abort_entries_match( raise data_entry_flow.AbortFlow("already_configured") -class ConfigFlow(data_entry_flow.FlowHandler): +class ConfigEntryBaseFlow(data_entry_flow.BaseFlowHandler[ConfigFlowResult]): + """Base class for config and option flows.""" + + _flow_result = ConfigFlowResult + + +class ConfigFlow(ConfigEntryBaseFlow): """Base class for config flows with some helpers.""" def __init_subclass__(cls, *, domain: str | None = None, **kwargs: Any) -> None: @@ -2008,7 +2019,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): self, include_uninitialized: bool = False, match_context: dict[str, Any] | None = None, - ) -> list[data_entry_flow.FlowResult]: + ) -> list[ConfigFlowResult]: """Return other in progress flows for current domain.""" return [ flw @@ -2020,22 +2031,18 @@ class ConfigFlow(data_entry_flow.FlowHandler): if flw["flow_id"] != self.flow_id ] - async def async_step_ignore( - self, user_input: dict[str, Any] - ) -> data_entry_flow.FlowResult: + async def async_step_ignore(self, user_input: dict[str, Any]) -> ConfigFlowResult: """Ignore this config flow.""" await self.async_set_unique_id(user_input["unique_id"], raise_on_progress=False) return self.async_create_entry(title=user_input["title"], data={}) - async def async_step_unignore( - self, user_input: dict[str, Any] - ) -> data_entry_flow.FlowResult: + async def async_step_unignore(self, user_input: dict[str, Any]) -> ConfigFlowResult: """Rediscover a config entry by it's unique_id.""" return self.async_abort(reason="not_implemented") async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initiated by the user.""" return self.async_abort(reason="not_implemented") @@ -2068,14 +2075,14 @@ class ConfigFlow(data_entry_flow.FlowHandler): async def _async_step_discovery_without_unique_id( self, - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by discovery.""" await self._async_handle_discovery_without_unique_id() return await self.async_step_user() async def async_step_discovery( self, discovery_info: DiscoveryInfoType - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by discovery.""" return await self._async_step_discovery_without_unique_id() @@ -2085,7 +2092,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): *, reason: str, description_placeholders: Mapping[str, str] | None = None, - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Abort the config flow.""" # Remove reauth notification if no reauth flows are in progress if self.source == SOURCE_REAUTH and not any( @@ -2104,55 +2111,53 @@ class ConfigFlow(data_entry_flow.FlowHandler): async def async_step_bluetooth( self, discovery_info: BluetoothServiceInfoBleak - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by Bluetooth discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_dhcp( self, discovery_info: DhcpServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by DHCP discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_hassio( self, discovery_info: HassioServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by HASS IO discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_integration_discovery( self, discovery_info: DiscoveryInfoType - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by integration specific discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_homekit( self, discovery_info: ZeroconfServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by Homekit discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_mqtt( self, discovery_info: MqttServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by MQTT discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_ssdp( self, discovery_info: SsdpServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by SSDP discovery.""" return await self._async_step_discovery_without_unique_id() - async def async_step_usb( - self, discovery_info: UsbServiceInfo - ) -> data_entry_flow.FlowResult: + async def async_step_usb(self, discovery_info: UsbServiceInfo) -> ConfigFlowResult: """Handle a flow initialized by USB discovery.""" return await self._async_step_discovery_without_unique_id() async def async_step_zeroconf( self, discovery_info: ZeroconfServiceInfo - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Handle a flow initialized by Zeroconf discovery.""" return await self._async_step_discovery_without_unique_id() @@ -2165,7 +2170,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): description: str | None = None, description_placeholders: Mapping[str, str] | None = None, options: Mapping[str, Any] | None = None, - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Finish config flow and create a config entry.""" result = super().async_create_entry( title=title, @@ -2175,6 +2180,8 @@ class ConfigFlow(data_entry_flow.FlowHandler): ) result["options"] = options or {} + result["minor_version"] = self.MINOR_VERSION + result["version"] = self.VERSION return result @@ -2188,7 +2195,7 @@ class ConfigFlow(data_entry_flow.FlowHandler): data: Mapping[str, Any] | UndefinedType = UNDEFINED, options: Mapping[str, Any] | UndefinedType = UNDEFINED, reason: str = "reauth_successful", - ) -> data_entry_flow.FlowResult: + ) -> ConfigFlowResult: """Update config entry, reload config entry and finish config flow.""" result = self.hass.config_entries.async_update_entry( entry=entry, @@ -2202,9 +2209,11 @@ class ConfigFlow(data_entry_flow.FlowHandler): return self.async_abort(reason=reason) -class OptionsFlowManager(data_entry_flow.FlowManager): +class OptionsFlowManager(data_entry_flow.BaseFlowManager[ConfigFlowResult]): """Flow to set options for a configuration entry.""" + _flow_result = ConfigFlowResult + def _async_get_config_entry(self, config_entry_id: str) -> ConfigEntry: """Return config entry or raise if not found.""" entry = self.hass.config_entries.async_get_entry(config_entry_id) @@ -2229,8 +2238,8 @@ class OptionsFlowManager(data_entry_flow.FlowManager): return handler.async_get_options_flow(entry) async def async_finish_flow( - self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult - ) -> data_entry_flow.FlowResult: + self, flow: data_entry_flow.BaseFlowHandler, result: ConfigFlowResult + ) -> ConfigFlowResult: """Finish an options flow and update options for configuration entry. Flow.handler and entry_id is the same thing to map flow with entry. @@ -2249,7 +2258,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager): result["result"] = True return result - async def _async_setup_preview(self, flow: data_entry_flow.FlowHandler) -> None: + async def _async_setup_preview(self, flow: data_entry_flow.BaseFlowHandler) -> None: """Set up preview for an option flow handler.""" entry = self._async_get_config_entry(flow.handler) await _load_integration(self.hass, entry.domain, {}) @@ -2258,7 +2267,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager): await flow.async_setup_preview(self.hass) -class OptionsFlow(data_entry_flow.FlowHandler): +class OptionsFlow(ConfigEntryBaseFlow): """Base class for config options flows.""" handler: str diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index bbb6621cfcc..b573f528945 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -11,7 +11,7 @@ from enum import StrEnum from functools import partial import logging from types import MappingProxyType -from typing import Any, Required, TypedDict +from typing import Any, Generic, Required, TypedDict, TypeVar import voluptuous as vol @@ -75,6 +75,7 @@ FLOW_NOT_COMPLETE_STEPS = { FlowResultType.MENU, } + STEP_ID_OPTIONAL_STEPS = { FlowResultType.EXTERNAL_STEP, FlowResultType.FORM, @@ -83,6 +84,9 @@ STEP_ID_OPTIONAL_STEPS = { } +_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult") + + @dataclass(slots=True) class BaseServiceInfo: """Base class for discovery ServiceInfo.""" @@ -163,26 +167,6 @@ class FlowResult(TypedDict, total=False): version: int -@callback -def _async_flow_handler_to_flow_result( - flows: Iterable[FlowHandler], include_uninitialized: bool -) -> list[FlowResult]: - """Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" - results = [] - for flow in flows: - if not include_uninitialized and flow.cur_step is None: - continue - result = FlowResult( - flow_id=flow.flow_id, - handler=flow.handler, - context=flow.context, - ) - if flow.cur_step: - result["step_id"] = flow.cur_step["step_id"] - results.append(result) - return results - - def _map_error_to_schema_errors( schema_errors: dict[str, Any], error: vol.Invalid, @@ -206,9 +190,11 @@ def _map_error_to_schema_errors( schema_errors[path_part_str] = error.error_message -class FlowManager(abc.ABC): +class BaseFlowManager(abc.ABC, Generic[_FlowResultT]): """Manage all the flows that are in progress.""" + _flow_result: Callable[..., _FlowResultT] + def __init__( self, hass: HomeAssistant, @@ -216,9 +202,9 @@ class FlowManager(abc.ABC): """Initialize the flow manager.""" self.hass = hass self._preview: set[str] = set() - self._progress: dict[str, FlowHandler] = {} - self._handler_progress_index: dict[str, set[FlowHandler]] = {} - self._init_data_process_index: dict[type, set[FlowHandler]] = {} + self._progress: dict[str, BaseFlowHandler] = {} + self._handler_progress_index: dict[str, set[BaseFlowHandler]] = {} + self._init_data_process_index: dict[type, set[BaseFlowHandler]] = {} @abc.abstractmethod async def async_create_flow( @@ -227,7 +213,7 @@ class FlowManager(abc.ABC): *, context: dict[str, Any] | None = None, data: dict[str, Any] | None = None, - ) -> FlowHandler: + ) -> BaseFlowHandler[_FlowResultT]: """Create a flow for specified handler. Handler key is the domain of the component that we want to set up. @@ -235,11 +221,13 @@ class FlowManager(abc.ABC): @abc.abstractmethod async def async_finish_flow( - self, flow: FlowHandler, result: FlowResult - ) -> FlowResult: + self, flow: BaseFlowHandler, result: _FlowResultT + ) -> _FlowResultT: """Finish a data entry flow.""" - async def async_post_init(self, flow: FlowHandler, result: FlowResult) -> None: + async def async_post_init( + self, flow: BaseFlowHandler, result: _FlowResultT + ) -> None: """Entry has finished executing its first step asynchronously.""" @callback @@ -262,16 +250,16 @@ class FlowManager(abc.ABC): return False @callback - def async_get(self, flow_id: str) -> FlowResult: + def async_get(self, flow_id: str) -> _FlowResultT: """Return a flow in progress as a partial FlowResult.""" if (flow := self._progress.get(flow_id)) is None: raise UnknownFlow - return _async_flow_handler_to_flow_result([flow], False)[0] + return self._async_flow_handler_to_flow_result([flow], False)[0] @callback - def async_progress(self, include_uninitialized: bool = False) -> list[FlowResult]: + def async_progress(self, include_uninitialized: bool = False) -> list[_FlowResultT]: """Return the flows in progress as a partial FlowResult.""" - return _async_flow_handler_to_flow_result( + return self._async_flow_handler_to_flow_result( self._progress.values(), include_uninitialized ) @@ -281,13 +269,13 @@ class FlowManager(abc.ABC): handler: str, include_uninitialized: bool = False, match_context: dict[str, Any] | None = None, - ) -> list[FlowResult]: + ) -> list[_FlowResultT]: """Return the flows in progress by handler as a partial FlowResult. If match_context is specified, only return flows with a context that is a superset of match_context. """ - return _async_flow_handler_to_flow_result( + return self._async_flow_handler_to_flow_result( self._async_progress_by_handler(handler, match_context), include_uninitialized, ) @@ -298,9 +286,9 @@ class FlowManager(abc.ABC): init_data_type: type, matcher: Callable[[Any], bool], include_uninitialized: bool = False, - ) -> list[FlowResult]: + ) -> list[_FlowResultT]: """Return flows in progress init matching by data type as a partial FlowResult.""" - return _async_flow_handler_to_flow_result( + return self._async_flow_handler_to_flow_result( ( progress for progress in self._init_data_process_index.get(init_data_type, set()) @@ -312,7 +300,7 @@ class FlowManager(abc.ABC): @callback def _async_progress_by_handler( self, handler: str, match_context: dict[str, Any] | None - ) -> list[FlowHandler]: + ) -> list[BaseFlowHandler[_FlowResultT]]: """Return the flows in progress by handler. If match_context is specified, only return flows with a context that @@ -329,7 +317,7 @@ class FlowManager(abc.ABC): async def async_init( self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None - ) -> FlowResult: + ) -> _FlowResultT: """Start a data entry flow.""" if context is None: context = {} @@ -352,9 +340,9 @@ class FlowManager(abc.ABC): async def async_configure( self, flow_id: str, user_input: dict | None = None - ) -> FlowResult: + ) -> _FlowResultT: """Continue a data entry flow.""" - result: FlowResult | None = None + result: _FlowResultT | None = None while not result or result["type"] == FlowResultType.SHOW_PROGRESS_DONE: result = await self._async_configure(flow_id, user_input) flow = self._progress.get(flow_id) @@ -364,7 +352,7 @@ class FlowManager(abc.ABC): async def _async_configure( self, flow_id: str, user_input: dict | None = None - ) -> FlowResult: + ) -> _FlowResultT: """Continue a data entry flow.""" if (flow := self._progress.get(flow_id)) is None: raise UnknownFlow @@ -458,7 +446,7 @@ class FlowManager(abc.ABC): self._async_remove_flow_progress(flow_id) @callback - def _async_add_flow_progress(self, flow: FlowHandler) -> None: + def _async_add_flow_progress(self, flow: BaseFlowHandler[_FlowResultT]) -> None: """Add a flow to in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) @@ -467,7 +455,9 @@ class FlowManager(abc.ABC): self._handler_progress_index.setdefault(flow.handler, set()).add(flow) @callback - def _async_remove_flow_from_index(self, flow: FlowHandler) -> None: + def _async_remove_flow_from_index( + self, flow: BaseFlowHandler[_FlowResultT] + ) -> None: """Remove a flow from in progress.""" if flow.init_data is not None: init_data_type = type(flow.init_data) @@ -492,17 +482,24 @@ class FlowManager(abc.ABC): _LOGGER.exception("Error removing %s flow: %s", flow.handler, err) async def _async_handle_step( - self, flow: FlowHandler, step_id: str, user_input: dict | BaseServiceInfo | None - ) -> FlowResult: + self, + flow: BaseFlowHandler[_FlowResultT], + step_id: str, + user_input: dict | BaseServiceInfo | None, + ) -> _FlowResultT: """Handle a step of a flow.""" self._raise_if_step_does_not_exist(flow, step_id) method = f"async_step_{step_id}" try: - result: FlowResult = await getattr(flow, method)(user_input) + result: _FlowResultT = await getattr(flow, method)(user_input) except AbortFlow as err: - result = _create_abort_data( - flow.flow_id, flow.handler, err.reason, err.description_placeholders + result = self._flow_result( + type=FlowResultType.ABORT, + flow_id=flow.flow_id, + handler=flow.handler, + reason=err.reason, + description_placeholders=err.description_placeholders, ) # Setup the flow handler's preview if needed @@ -521,7 +518,8 @@ class FlowManager(abc.ABC): if ( result["type"] == FlowResultType.SHOW_PROGRESS - and (progress_task := result.pop("progress_task", None)) + # Mypy does not agree with using pop on _FlowResultT + and (progress_task := result.pop("progress_task", None)) # type: ignore[arg-type] and progress_task != flow.async_get_progress_task() ): # The flow's progress task was changed, register a callback on it @@ -532,8 +530,9 @@ class FlowManager(abc.ABC): def schedule_configure(_: asyncio.Task) -> None: self.hass.async_create_task(call_configure()) - progress_task.add_done_callback(schedule_configure) - flow.async_set_progress_task(progress_task) + # The mypy ignores are a consequence of mypy not accepting the pop above + progress_task.add_done_callback(schedule_configure) # type: ignore[attr-defined] + flow.async_set_progress_task(progress_task) # type: ignore[arg-type] elif result["type"] != FlowResultType.SHOW_PROGRESS: flow.async_cancel_progress_task() @@ -560,7 +559,9 @@ class FlowManager(abc.ABC): return result - def _raise_if_step_does_not_exist(self, flow: FlowHandler, step_id: str) -> None: + def _raise_if_step_does_not_exist( + self, flow: BaseFlowHandler, step_id: str + ) -> None: """Raise if the step does not exist.""" method = f"async_step_{step_id}" @@ -570,18 +571,45 @@ class FlowManager(abc.ABC): f"Handler {self.__class__.__name__} doesn't support step {step_id}" ) - async def _async_setup_preview(self, flow: FlowHandler) -> None: + async def _async_setup_preview(self, flow: BaseFlowHandler) -> None: """Set up preview for a flow handler.""" if flow.handler not in self._preview: self._preview.add(flow.handler) await flow.async_setup_preview(self.hass) + @callback + def _async_flow_handler_to_flow_result( + self, flows: Iterable[BaseFlowHandler], include_uninitialized: bool + ) -> list[_FlowResultT]: + """Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" + results = [] + for flow in flows: + if not include_uninitialized and flow.cur_step is None: + continue + result = self._flow_result( + flow_id=flow.flow_id, + handler=flow.handler, + context=flow.context, + ) + if flow.cur_step: + result["step_id"] = flow.cur_step["step_id"] + results.append(result) + return results -class FlowHandler: + +class FlowManager(BaseFlowManager[FlowResult]): + """Manage all the flows that are in progress.""" + + _flow_result = FlowResult + + +class BaseFlowHandler(Generic[_FlowResultT]): """Handle a data entry flow.""" + _flow_result: Callable[..., _FlowResultT] + # Set by flow manager - cur_step: FlowResult | None = None + cur_step: _FlowResultT | None = None # While not purely typed, it makes typehinting more useful for us # and removes the need for constant None checks or asserts. @@ -657,12 +685,12 @@ class FlowHandler: description_placeholders: Mapping[str, str | None] | None = None, last_step: bool | None = None, preview: str | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Return the definition of a form to gather user input. The step_id parameter is deprecated and will be removed in a future release. """ - flow_result = FlowResult( + flow_result = self._flow_result( type=FlowResultType.FORM, flow_id=self.flow_id, handler=self.handler, @@ -684,11 +712,9 @@ class FlowHandler: data: Mapping[str, Any], description: str | None = None, description_placeholders: Mapping[str, str] | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Finish flow.""" - flow_result = FlowResult( - version=self.VERSION, - minor_version=self.MINOR_VERSION, + flow_result = self._flow_result( type=FlowResultType.CREATE_ENTRY, flow_id=self.flow_id, handler=self.handler, @@ -707,10 +733,14 @@ class FlowHandler: *, reason: str, description_placeholders: Mapping[str, str] | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Abort the flow.""" - return _create_abort_data( - self.flow_id, self.handler, reason, description_placeholders + return self._flow_result( + type=FlowResultType.ABORT, + flow_id=self.flow_id, + handler=self.handler, + reason=reason, + description_placeholders=description_placeholders, ) @callback @@ -720,12 +750,12 @@ class FlowHandler: step_id: str | None = None, url: str, description_placeholders: Mapping[str, str] | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Return the definition of an external step for the user to take. The step_id parameter is deprecated and will be removed in a future release. """ - flow_result = FlowResult( + flow_result = self._flow_result( type=FlowResultType.EXTERNAL_STEP, flow_id=self.flow_id, handler=self.handler, @@ -737,9 +767,9 @@ class FlowHandler: return flow_result @callback - def async_external_step_done(self, *, next_step_id: str) -> FlowResult: + def async_external_step_done(self, *, next_step_id: str) -> _FlowResultT: """Return the definition of an external step for the user to take.""" - return FlowResult( + return self._flow_result( type=FlowResultType.EXTERNAL_STEP_DONE, flow_id=self.flow_id, handler=self.handler, @@ -754,7 +784,7 @@ class FlowHandler: progress_action: str, description_placeholders: Mapping[str, str] | None = None, progress_task: asyncio.Task[Any] | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Show a progress message to the user, without user input allowed. The step_id parameter is deprecated and will be removed in a future release. @@ -777,7 +807,7 @@ class FlowHandler: if progress_task is None: self.deprecated_show_progress = True - flow_result = FlowResult( + flow_result = self._flow_result( type=FlowResultType.SHOW_PROGRESS, flow_id=self.flow_id, handler=self.handler, @@ -790,9 +820,9 @@ class FlowHandler: return flow_result @callback - def async_show_progress_done(self, *, next_step_id: str) -> FlowResult: + def async_show_progress_done(self, *, next_step_id: str) -> _FlowResultT: """Mark the progress done.""" - return FlowResult( + return self._flow_result( type=FlowResultType.SHOW_PROGRESS_DONE, flow_id=self.flow_id, handler=self.handler, @@ -806,13 +836,13 @@ class FlowHandler: step_id: str | None = None, menu_options: list[str] | dict[str, str], description_placeholders: Mapping[str, str] | None = None, - ) -> FlowResult: + ) -> _FlowResultT: """Show a navigation menu to the user. Options dict maps step_id => i18n label The step_id parameter is deprecated and will be removed in a future release. """ - flow_result = FlowResult( + flow_result = self._flow_result( type=FlowResultType.MENU, flow_id=self.flow_id, handler=self.handler, @@ -853,21 +883,10 @@ class FlowHandler: self.__progress_task = progress_task -@callback -def _create_abort_data( - flow_id: str, - handler: str, - reason: str, - description_placeholders: Mapping[str, str] | None = None, -) -> FlowResult: - """Return the definition of an external step for the user to take.""" - return FlowResult( - type=FlowResultType.ABORT, - flow_id=flow_id, - handler=handler, - reason=reason, - description_placeholders=description_placeholders, - ) +class FlowHandler(BaseFlowHandler[FlowResult]): + """Handle a data entry flow.""" + + _flow_result = FlowResult # These can be removed if no deprecated constant are in this module anymore diff --git a/homeassistant/helpers/config_entry_flow.py b/homeassistant/helpers/config_entry_flow.py index 6cdedf98f97..b645fdb06bd 100644 --- a/homeassistant/helpers/config_entry_flow.py +++ b/homeassistant/helpers/config_entry_flow.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any, Generic, TypeVar, cast from homeassistant import config_entries from homeassistant.components import onboarding from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResult from .typing import DiscoveryInfoType @@ -46,7 +45,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by the user.""" if self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -57,7 +56,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_confirm( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Confirm setup.""" if user_input is None and onboarding.async_is_onboarded(self.hass): self._set_confirm_only() @@ -87,7 +86,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_discovery( self, discovery_info: DiscoveryInfoType - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -98,7 +97,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_bluetooth( self, discovery_info: BluetoothServiceInfoBleak - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by bluetooth discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -107,7 +106,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): return await self.async_step_confirm() - async def async_step_dhcp(self, discovery_info: DhcpServiceInfo) -> FlowResult: + async def async_step_dhcp( + self, discovery_info: DhcpServiceInfo + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by dhcp discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -118,7 +119,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_homekit( self, discovery_info: ZeroconfServiceInfo - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by Homekit discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -127,7 +128,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): return await self.async_step_confirm() - async def async_step_mqtt(self, discovery_info: MqttServiceInfo) -> FlowResult: + async def async_step_mqtt( + self, discovery_info: MqttServiceInfo + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by mqtt discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -138,7 +141,7 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): async def async_step_zeroconf( self, discovery_info: ZeroconfServiceInfo - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by Zeroconf discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -147,7 +150,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): return await self.async_step_confirm() - async def async_step_ssdp(self, discovery_info: SsdpServiceInfo) -> FlowResult: + async def async_step_ssdp( + self, discovery_info: SsdpServiceInfo + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by Ssdp discovery.""" if self._async_in_progress() or self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -156,7 +161,9 @@ class DiscoveryFlowHandler(config_entries.ConfigFlow, Generic[_R]): return await self.async_step_confirm() - async def async_step_import(self, _: dict[str, Any] | None) -> FlowResult: + async def async_step_import( + self, _: dict[str, Any] | None + ) -> config_entries.ConfigFlowResult: """Handle a flow initialized by import.""" if self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") @@ -205,7 +212,7 @@ class WebhookFlowHandler(config_entries.ConfigFlow): async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a user initiated set up flow to create a webhook.""" if not self._allow_multiple and self._async_current_entries(): return self.async_abort(reason="single_instance_allowed") diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index d99cc1d4f76..337e6ca92b6 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -25,7 +25,6 @@ from yarl import URL from homeassistant import config_entries from homeassistant.components import http from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.loader import async_get_application_credentials from .aiohttp_client import async_get_clientsession @@ -253,7 +252,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): async def async_step_pick_implementation( self, user_input: dict | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow start.""" implementations = await async_get_implementations(self.hass, self.DOMAIN) @@ -286,7 +285,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): async def async_step_auth( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Create an entry for auth.""" # Flow has been triggered by external data if user_input is not None: @@ -314,7 +313,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): async def async_step_creation( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Create config entry from external data.""" _LOGGER.debug("Creating config entry from external data") @@ -353,14 +352,18 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): {"auth_implementation": self.flow_impl.domain, "token": token} ) - async def async_step_authorize_rejected(self, data: None = None) -> FlowResult: + async def async_step_authorize_rejected( + self, data: None = None + ) -> config_entries.ConfigFlowResult: """Step to handle flow rejection.""" return self.async_abort( reason="user_rejected_authorize", description_placeholders={"error": self.external_data["error"]}, ) - async def async_oauth_create_entry(self, data: dict) -> FlowResult: + async def async_oauth_create_entry( + self, data: dict + ) -> config_entries.ConfigFlowResult: """Create an entry for the flow. Ok to override if you want to fetch extra info or even add another step. @@ -369,7 +372,7 @@ class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> config_entries.ConfigFlowResult: """Handle a flow start.""" return await self.async_step_pick_implementation(user_input) diff --git a/homeassistant/helpers/data_entry_flow.py b/homeassistant/helpers/data_entry_flow.py index 695fbbf7633..6a6e48caa7e 100644 --- a/homeassistant/helpers/data_entry_flow.py +++ b/homeassistant/helpers/data_entry_flow.py @@ -18,7 +18,7 @@ from . import config_validation as cv class _BaseFlowManagerView(HomeAssistantView): """Foundation for flow manager views.""" - def __init__(self, flow_mgr: data_entry_flow.FlowManager) -> None: + def __init__(self, flow_mgr: data_entry_flow.BaseFlowManager) -> None: """Initialize the flow manager index view.""" self._flow_mgr = flow_mgr diff --git a/homeassistant/helpers/discovery_flow.py b/homeassistant/helpers/discovery_flow.py index c4698de1f52..a24e87325ae 100644 --- a/homeassistant/helpers/discovery_flow.py +++ b/homeassistant/helpers/discovery_flow.py @@ -4,9 +4,9 @@ from __future__ import annotations from collections.abc import Coroutine from typing import Any, NamedTuple +from homeassistant.config_entries import ConfigFlowResult from homeassistant.const import EVENT_HOMEASSISTANT_STARTED from homeassistant.core import CoreState, Event, HomeAssistant, callback -from homeassistant.data_entry_flow import FlowResult from homeassistant.loader import bind_hass from homeassistant.util.async_ import gather_with_limited_concurrency @@ -40,7 +40,7 @@ def async_create_flow( @callback def _async_init_flow( hass: HomeAssistant, domain: str, context: dict[str, Any], data: Any -) -> Coroutine[None, None, FlowResult] | None: +) -> Coroutine[None, None, ConfigFlowResult] | None: """Create a discovery flow.""" # Avoid spawning flows that have the same initial discovery data # as ones in progress as it may cause additional device probing diff --git a/homeassistant/helpers/schema_config_entry_flow.py b/homeassistant/helpers/schema_config_entry_flow.py index 2bbad0ed63a..d5563c995ff 100644 --- a/homeassistant/helpers/schema_config_entry_flow.py +++ b/homeassistant/helpers/schema_config_entry_flow.py @@ -10,9 +10,15 @@ from typing import Any, cast import voluptuous as vol -from homeassistant import config_entries +from homeassistant.config_entries import ( + ConfigEntry, + ConfigFlow, + ConfigFlowResult, + OptionsFlow, + OptionsFlowWithConfigEntry, +) from homeassistant.core import HomeAssistant, callback, split_entity_id -from homeassistant.data_entry_flow import FlowResult, UnknownHandler +from homeassistant.data_entry_flow import UnknownHandler from . import entity_registry as er, selector from .typing import UNDEFINED, UndefinedType @@ -126,7 +132,7 @@ class SchemaCommonFlowHandler: async def async_step( self, step_id: str, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a step.""" if isinstance(self._flow[step_id], SchemaFlowFormStep): return await self._async_form_step(step_id, user_input) @@ -141,7 +147,7 @@ class SchemaCommonFlowHandler: async def _async_form_step( self, step_id: str, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a form step.""" form_step: SchemaFlowFormStep = cast(SchemaFlowFormStep, self._flow[step_id]) @@ -204,7 +210,7 @@ class SchemaCommonFlowHandler: async def _show_next_step_or_create_entry( self, form_step: SchemaFlowFormStep - ) -> FlowResult: + ) -> ConfigFlowResult: next_step_id_or_end_flow: str | None if callable(form_step.next_step): @@ -222,7 +228,7 @@ class SchemaCommonFlowHandler: next_step_id: str, error: SchemaFlowError | None = None, user_input: dict[str, Any] | None = None, - ) -> FlowResult: + ) -> ConfigFlowResult: """Show form for next step.""" if isinstance(self._flow[next_step_id], SchemaFlowMenuStep): menu_step = cast(SchemaFlowMenuStep, self._flow[next_step_id]) @@ -271,7 +277,7 @@ class SchemaCommonFlowHandler: async def _async_menu_step( self, step_id: str, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a menu step.""" menu_step: SchemaFlowMenuStep = cast(SchemaFlowMenuStep, self._flow[step_id]) return self._handler.async_show_menu( @@ -280,7 +286,7 @@ class SchemaCommonFlowHandler: ) -class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): +class SchemaConfigFlowHandler(ConfigFlow, ABC): """Handle a schema based config flow.""" config_flow: Mapping[str, SchemaFlowStep] @@ -294,8 +300,8 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): @callback def _async_get_options_flow( - config_entry: config_entries.ConfigEntry, - ) -> config_entries.OptionsFlow: + config_entry: ConfigEntry, + ) -> OptionsFlow: """Get the options flow for this handler.""" if cls.options_flow is None: raise UnknownHandler @@ -324,9 +330,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): @classmethod @callback - def async_supports_options_flow( - cls, config_entry: config_entries.ConfigEntry - ) -> bool: + def async_supports_options_flow(cls, config_entry: ConfigEntry) -> bool: """Return options flow support for this handler.""" return cls.options_flow is not None @@ -335,13 +339,13 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): step_id: str, ) -> Callable[ [SchemaConfigFlowHandler, dict[str, Any] | None], - Coroutine[Any, Any, FlowResult], + Coroutine[Any, Any, ConfigFlowResult], ]: """Generate a step handler.""" async def _async_step( self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle a config flow step.""" # pylint: disable-next=protected-access result = await self._common_handler.async_step(step_id, user_input) @@ -382,7 +386,7 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): self, data: Mapping[str, Any], **kwargs: Any, - ) -> FlowResult: + ) -> ConfigFlowResult: """Finish config flow and create a config entry.""" self.async_config_flow_finished(data) return super().async_create_entry( @@ -390,12 +394,12 @@ class SchemaConfigFlowHandler(config_entries.ConfigFlow, ABC): ) -class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): +class SchemaOptionsFlowHandler(OptionsFlowWithConfigEntry): """Handle a schema based options flow.""" def __init__( self, - config_entry: config_entries.ConfigEntry, + config_entry: ConfigEntry, options_flow: Mapping[str, SchemaFlowStep], async_options_flow_finished: Callable[[HomeAssistant, Mapping[str, Any]], None] | None = None, @@ -430,13 +434,13 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): step_id: str, ) -> Callable[ [SchemaConfigFlowHandler, dict[str, Any] | None], - Coroutine[Any, Any, FlowResult], + Coroutine[Any, Any, ConfigFlowResult], ]: """Generate a step handler.""" async def _async_step( self: SchemaConfigFlowHandler, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle an options flow step.""" # pylint: disable-next=protected-access result = await self._common_handler.async_step(step_id, user_input) @@ -449,7 +453,7 @@ class SchemaOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry): self, data: Mapping[str, Any], **kwargs: Any, - ) -> FlowResult: + ) -> ConfigFlowResult: """Finish config flow and create a config entry.""" if self._async_options_flow_finished: self._async_options_flow_finished(self.hass, data) diff --git a/pylint/plugins/hass_enforce_type_hints.py b/pylint/plugins/hass_enforce_type_hints.py index b2620dd3e1e..602bd8a443d 100644 --- a/pylint/plugins/hass_enforce_type_hints.py +++ b/pylint/plugins/hass_enforce_type_hints.py @@ -55,11 +55,12 @@ class TypeHintMatch: ) -@dataclass +@dataclass(kw_only=True) class ClassTypeHintMatch: """Class for pattern matching.""" base_class: str + exclude_base_classes: set[str] | None = None matches: list[TypeHintMatch] @@ -481,6 +482,7 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { "config_flow": [ ClassTypeHintMatch( base_class="FlowHandler", + exclude_base_classes={"ConfigEntryBaseFlow"}, matches=[ TypeHintMatch( function_name="async_step_*", @@ -492,6 +494,11 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { ClassTypeHintMatch( base_class="ConfigFlow", matches=[ + TypeHintMatch( + function_name="async_step123_*", + arg_types={}, + return_type=["ConfigFlowResult", "FlowResult"], + ), TypeHintMatch( function_name="async_get_options_flow", arg_types={ @@ -504,56 +511,66 @@ _CLASS_MATCH: dict[str, list[ClassTypeHintMatch]] = { arg_types={ 1: "DhcpServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_hassio", arg_types={ 1: "HassioServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_homekit", arg_types={ 1: "ZeroconfServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_mqtt", arg_types={ 1: "MqttServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_reauth", arg_types={ 1: "Mapping[str, Any]", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_ssdp", arg_types={ 1: "SsdpServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_usb", arg_types={ 1: "UsbServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], ), TypeHintMatch( function_name="async_step_zeroconf", arg_types={ 1: "ZeroconfServiceInfo", }, - return_type="FlowResult", + return_type=["ConfigFlowResult", "FlowResult"], + ), + ], + ), + ClassTypeHintMatch( + base_class="OptionsFlow", + matches=[ + TypeHintMatch( + function_name="async_step_*", + arg_types={}, + return_type=["ConfigFlowResult", "FlowResult"], ), ], ), @@ -3126,11 +3143,19 @@ class HassTypeHintChecker(BaseChecker): ancestor: nodes.ClassDef checked_class_methods: set[str] = set() ancestors = list(node.ancestors()) # cache result for inside loop - for class_matches in self._class_matchers: + for class_matcher in self._class_matchers: + skip_matcher = False + if exclude_base_classes := class_matcher.exclude_base_classes: + for ancestor in ancestors: + if ancestor.name in exclude_base_classes: + skip_matcher = True + break + if skip_matcher: + continue for ancestor in ancestors: - if ancestor.name == class_matches.base_class: + if ancestor.name == class_matcher.base_class: self._visit_class_functions( - node, class_matches.matches, checked_class_methods + node, class_matcher.matches, checked_class_methods ) def _visit_class_functions( diff --git a/script/scaffold/templates/config_flow/integration/config_flow.py b/script/scaffold/templates/config_flow/integration/config_flow.py index caef6c2e729..f68059584f7 100644 --- a/script/scaffold/templates/config_flow/integration/config_flow.py +++ b/script/scaffold/templates/config_flow/integration/config_flow.py @@ -6,10 +6,9 @@ from typing import Any import voluptuous as vol -from homeassistant import config_entries +from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_USERNAME from homeassistant.core import HomeAssistant -from homeassistant.data_entry_flow import FlowResult from homeassistant.exceptions import HomeAssistantError from .const import DOMAIN @@ -68,14 +67,14 @@ async def validate_input(hass: HomeAssistant, data: dict[str, Any]) -> dict[str, return {"title": "Name of the device"} -class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): +class ConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for NEW_NAME.""" VERSION = 1 async def async_step_user( self, user_input: dict[str, Any] | None = None - ) -> FlowResult: + ) -> ConfigFlowResult: """Handle the initial step.""" errors: dict[str, str] = {} if user_input is not None: diff --git a/tests/components/cloud/test_repairs.py b/tests/components/cloud/test_repairs.py index 0e662c30ee7..9380cec2ebd 100644 --- a/tests/components/cloud/test_repairs.py +++ b/tests/components/cloud/test_repairs.py @@ -147,13 +147,11 @@ async def test_legacy_subscription_repair_flow( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": DOMAIN, "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue( diff --git a/tests/components/config/test_config_entries.py b/tests/components/config/test_config_entries.py index 844b4bdb3b4..6573a83b061 100644 --- a/tests/components/config/test_config_entries.py +++ b/tests/components/config/test_config_entries.py @@ -941,10 +941,8 @@ async def test_two_step_options_flow(hass: HomeAssistant, client) -> None: "handler": "test1", "type": "create_entry", "title": "Enable disable", - "version": 1, "description": None, "description_placeholders": None, - "minor_version": 1, } diff --git a/tests/components/hassio/test_repairs.py b/tests/components/hassio/test_repairs.py index 5dd73a21615..6a9287e0331 100644 --- a/tests/components/hassio/test_repairs.py +++ b/tests/components/hassio/test_repairs.py @@ -94,13 +94,11 @@ async def test_supervisor_issue_repair_flow( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") @@ -190,13 +188,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") @@ -305,13 +301,11 @@ async def test_supervisor_issue_repair_flow_with_multiple_suggestions_and_confir flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") @@ -386,13 +380,11 @@ async def test_supervisor_issue_repair_flow_skip_confirmation( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") @@ -486,13 +478,11 @@ async def test_mount_failed_repair_flow( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") @@ -598,13 +588,11 @@ async def test_supervisor_issue_docker_config_repair_flow( flow_id = data["flow_id"] assert data == { - "version": 1, "type": "create_entry", "flow_id": flow_id, "handler": "hassio", "description": None, "description_placeholders": None, - "minor_version": 1, } assert not issue_registry.async_get_issue(domain="hassio", issue_id="1234") diff --git a/tests/components/kitchen_sink/test_init.py b/tests/components/kitchen_sink/test_init.py index b3f303fcfe1..fb81c87008e 100644 --- a/tests/components/kitchen_sink/test_init.py +++ b/tests/components/kitchen_sink/test_init.py @@ -244,9 +244,7 @@ async def test_issues_created( "description_placeholders": None, "flow_id": flow_id, "handler": DOMAIN, - "minor_version": 1, "type": "create_entry", - "version": 1, } await ws_client.send_json({"id": 4, "type": "repairs/list_issues"}) diff --git a/tests/components/repairs/test_websocket_api.py b/tests/components/repairs/test_websocket_api.py index 0cf6b22dc0c..ef08095ca79 100644 --- a/tests/components/repairs/test_websocket_api.py +++ b/tests/components/repairs/test_websocket_api.py @@ -338,9 +338,7 @@ async def test_fix_issue( "description_placeholders": None, "flow_id": flow_id, "handler": domain, - "minor_version": 1, "type": "create_entry", - "version": 1, } await ws_client.send_json({"id": 4, "type": "repairs/list_issues"}) diff --git a/tests/helpers/test_discovery_flow.py b/tests/helpers/test_discovery_flow.py index 0b3386f8e04..7dcf6256a59 100644 --- a/tests/helpers/test_discovery_flow.py +++ b/tests/helpers/test_discovery_flow.py @@ -63,7 +63,7 @@ async def test_async_create_flow_checks_existing_flows_after_startup( """Test existing flows prevent an identical ones from being after startup.""" hass.bus.async_fire(EVENT_HOMEASSISTANT_STARTED) with patch( - "homeassistant.data_entry_flow.FlowManager.async_has_matching_flow", + "homeassistant.data_entry_flow.BaseFlowManager.async_has_matching_flow", return_value=True, ): discovery_flow.async_create_flow( diff --git a/tests/helpers/test_schema_config_entry_flow.py b/tests/helpers/test_schema_config_entry_flow.py index 58f6a261aef..6778a168dd7 100644 --- a/tests/helpers/test_schema_config_entry_flow.py +++ b/tests/helpers/test_schema_config_entry_flow.py @@ -45,7 +45,7 @@ def manager_fixture(): handlers = Registry() entries = [] - class FlowManager(data_entry_flow.FlowManager): + class FlowManager(data_entry_flow.BaseFlowManager): """Test flow manager.""" async def async_create_flow(self, handler_key, *, context, data): @@ -105,7 +105,7 @@ async def test_name(hass: HomeAssistant, entity_registry: er.EntityRegistry) -> @pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) async def test_config_flow_advanced_option( - hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker + hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker ) -> None: """Test handling of advanced options in config flow.""" manager.hass = hass @@ -200,7 +200,7 @@ async def test_config_flow_advanced_option( @pytest.mark.parametrize("marker", (vol.Required, vol.Optional)) async def test_options_flow_advanced_option( - hass: HomeAssistant, manager: data_entry_flow.FlowManager, marker + hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager, marker ) -> None: """Test handling of advanced options in options flow.""" manager.hass = hass @@ -475,7 +475,7 @@ async def test_next_step_function(hass: HomeAssistant) -> None: async def test_suggested_values( - hass: HomeAssistant, manager: data_entry_flow.FlowManager + hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager ) -> None: """Test suggested_values handling in SchemaFlowFormStep.""" manager.hass = hass @@ -667,7 +667,7 @@ async def test_options_flow_state(hass: HomeAssistant) -> None: async def test_options_flow_omit_optional_keys( - hass: HomeAssistant, manager: data_entry_flow.FlowManager + hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager ) -> None: """Test handling of advanced options in options flow.""" manager.hass = hass diff --git a/tests/pylint/test_enforce_type_hints.py b/tests/pylint/test_enforce_type_hints.py index d23d5a849dd..2a03343cb82 100644 --- a/tests/pylint/test_enforce_type_hints.py +++ b/tests/pylint/test_enforce_type_hints.py @@ -346,7 +346,7 @@ def test_invalid_config_flow_step( pylint.testutils.MessageTest( msg_id="hass-return-type", node=func_node, - args=("FlowResult", "async_step_zeroconf"), + args=(["ConfigFlowResult", "FlowResult"], "async_step_zeroconf"), line=11, col_offset=4, end_line=11, @@ -374,7 +374,7 @@ def test_valid_config_flow_step( async def async_step_zeroconf( self, device_config: ZeroconfServiceInfo - ) -> FlowResult: + ) -> ConfigFlowResult: pass """, "homeassistant.components.pylint_test.config_flow", diff --git a/tests/test_data_entry_flow.py b/tests/test_data_entry_flow.py index d39c8faccef..96bd45d4e36 100644 --- a/tests/test_data_entry_flow.py +++ b/tests/test_data_entry_flow.py @@ -24,9 +24,11 @@ def manager(): handlers = Registry() entries = [] - class FlowManager(data_entry_flow.FlowManager): + class FlowManager(data_entry_flow.BaseFlowManager): """Test flow manager.""" + _flow_result = data_entry_flow.FlowResult + async def async_create_flow(self, handler_key, *, context, data): """Test create flow.""" handler = handlers.get(handler_key) @@ -79,7 +81,7 @@ async def test_configure_reuses_handler_instance(manager) -> None: assert len(manager.mock_created_entries) == 0 -async def test_configure_two_steps(manager: data_entry_flow.FlowManager) -> None: +async def test_configure_two_steps(manager: data_entry_flow.BaseFlowManager) -> None: """Test that we reuse instances.""" @manager.mock_reg_handler("test") @@ -211,7 +213,6 @@ async def test_create_saves_data(manager) -> None: assert len(manager.mock_created_entries) == 1 entry = manager.mock_created_entries[0] - assert entry["version"] == 5 assert entry["handler"] == "test" assert entry["title"] == "Test Title" assert entry["data"] == "Test Data" @@ -237,7 +238,6 @@ async def test_discovery_init_flow(manager) -> None: assert len(manager.mock_created_entries) == 1 entry = manager.mock_created_entries[0] - assert entry["version"] == 5 assert entry["handler"] == "test" assert entry["title"] == "hello" assert entry["data"] == data @@ -258,7 +258,7 @@ async def test_finish_callback_change_result_type(hass: HomeAssistant) -> None: step_id="init", data_schema=vol.Schema({"count": int}) ) - class FlowManager(data_entry_flow.FlowManager): + class FlowManager(data_entry_flow.BaseFlowManager): async def async_create_flow(self, handler_name, *, context, data): """Create a test flow.""" return TestFlow() @@ -775,7 +775,7 @@ async def test_async_get_unknown_flow(manager) -> None: async def test_async_has_matching_flow( - hass: HomeAssistant, manager: data_entry_flow.FlowManager + hass: HomeAssistant, manager: data_entry_flow.BaseFlowManager ) -> None: """Test we can check for matching flows.""" manager.hass = hass @@ -951,7 +951,7 @@ async def test_show_menu(hass: HomeAssistant, manager, menu_options) -> None: async def test_find_flows_by_init_data_type( - manager: data_entry_flow.FlowManager, + manager: data_entry_flow.BaseFlowManager, ) -> None: """Test we can find flows by init data type."""