Make FlowResult a generic type (#111952)

This commit is contained in:
Erik Montnemery 2024-03-07 12:41:14 +01:00 committed by GitHub
parent 008e025d5c
commit 82efb3d35b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 95 additions and 80 deletions

View File

@ -19,13 +19,13 @@ from homeassistant.core import (
HomeAssistant, HomeAssistant,
callback, callback,
) )
from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.event import async_track_point_in_utc_time from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import auth_store, jwt_wrapper, models from . import auth_store, jwt_wrapper, models
from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION from .const import ACCESS_TOKEN_EXPIRATION, GROUP_ID_ADMIN, REFRESH_TOKEN_EXPIRATION
from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config from .mfa_modules import MultiFactorAuthModule, auth_mfa_module_from_config
from .models import AuthFlowResult
from .providers import AuthProvider, LoginFlow, auth_provider_from_config from .providers import AuthProvider, LoginFlow, auth_provider_from_config
EVENT_USER_ADDED = "user_added" EVENT_USER_ADDED = "user_added"
@ -88,9 +88,13 @@ async def auth_manager_from_config(
return manager return manager
class AuthManagerFlowManager(data_entry_flow.FlowManager): class AuthManagerFlowManager(
data_entry_flow.FlowManager[AuthFlowResult, tuple[str, str]]
):
"""Manage authentication flows.""" """Manage authentication flows."""
_flow_result = AuthFlowResult
def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None: def __init__(self, hass: HomeAssistant, auth_manager: AuthManager) -> None:
"""Init auth manager flows.""" """Init auth manager flows."""
super().__init__(hass) super().__init__(hass)
@ -98,11 +102,11 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
async def async_create_flow( async def async_create_flow(
self, self,
handler_key: str, handler_key: tuple[str, str],
*, *,
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> data_entry_flow.FlowHandler: ) -> LoginFlow:
"""Create a login flow.""" """Create a login flow."""
auth_provider = self.auth_manager.get_auth_provider(*handler_key) auth_provider = self.auth_manager.get_auth_provider(*handler_key)
if not auth_provider: if not auth_provider:
@ -110,8 +114,10 @@ class AuthManagerFlowManager(data_entry_flow.FlowManager):
return await auth_provider.async_login_flow(context) return await auth_provider.async_login_flow(context)
async def async_finish_flow( async def async_finish_flow(
self, flow: data_entry_flow.FlowHandler, result: FlowResult self,
) -> FlowResult: flow: data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]],
result: AuthFlowResult,
) -> AuthFlowResult:
"""Return a user as result of login flow.""" """Return a user as result of login flow."""
flow = cast(LoginFlow, flow) flow = cast(LoginFlow, flow)

View File

@ -11,6 +11,7 @@ from attr import Attribute
from attr.setters import validate from attr.setters import validate
from homeassistant.const import __version__ from homeassistant.const import __version__
from homeassistant.data_entry_flow import FlowResult
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from . import permissions as perm_mdl from . import permissions as perm_mdl
@ -26,6 +27,8 @@ TOKEN_TYPE_NORMAL = "normal"
TOKEN_TYPE_SYSTEM = "system" TOKEN_TYPE_SYSTEM = "system"
TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token" TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN = "long_lived_access_token"
AuthFlowResult = FlowResult[tuple[str, str]]
@attr.s(slots=True) @attr.s(slots=True)
class Group: class Group:

View File

@ -13,14 +13,13 @@ from voluptuous.humanize import humanize_error
from homeassistant import data_entry_flow, requirements from homeassistant import data_entry_flow, requirements
from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE from homeassistant.const import CONF_ID, CONF_NAME, CONF_TYPE
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from homeassistant.util.decorator import Registry from homeassistant.util.decorator import Registry
from ..auth_store import AuthStore from ..auth_store import AuthStore
from ..const import MFA_SESSION_EXPIRATION from ..const import MFA_SESSION_EXPIRATION
from ..models import Credentials, RefreshToken, User, UserMeta from ..models import AuthFlowResult, Credentials, RefreshToken, User, UserMeta
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
DATA_REQS = "auth_prov_reqs_processed" DATA_REQS = "auth_prov_reqs_processed"
@ -181,9 +180,11 @@ async def load_auth_provider_module(
return module return module
class LoginFlow(data_entry_flow.FlowHandler): class LoginFlow(data_entry_flow.FlowHandler[AuthFlowResult, tuple[str, str]]):
"""Handler for the login flow.""" """Handler for the login flow."""
_flow_result = AuthFlowResult
def __init__(self, auth_provider: AuthProvider) -> None: def __init__(self, auth_provider: AuthProvider) -> None:
"""Initialize the login flow.""" """Initialize the login flow."""
self._auth_provider = auth_provider self._auth_provider = auth_provider
@ -197,7 +198,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the first step of login flow. """Handle the first step of login flow.
Return self.async_show_form(step_id='init') if user_input is None. Return self.async_show_form(step_id='init') if user_input is None.
@ -207,7 +208,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_select_mfa_module( async def async_step_select_mfa_module(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of select mfa module.""" """Handle the step of select mfa module."""
errors = {} errors = {}
@ -232,7 +233,7 @@ class LoginFlow(data_entry_flow.FlowHandler):
async def async_step_mfa( async def async_step_mfa(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of mfa validation.""" """Handle the step of mfa validation."""
assert self.credential assert self.credential
assert self.user assert self.user
@ -282,6 +283,6 @@ class LoginFlow(data_entry_flow.FlowHandler):
errors=errors, errors=errors,
) )
async def async_finish(self, flow_result: Any) -> FlowResult: async def async_finish(self, flow_result: Any) -> AuthFlowResult:
"""Handle the pass of login flow.""" """Handle the pass of login flow."""
return self.async_create_entry(data=flow_result) return self.async_create_entry(data=flow_result)

View File

@ -10,10 +10,9 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.const import CONF_COMMAND from homeassistant.const import CONF_COMMAND
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from ..models import Credentials, UserMeta from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
CONF_ARGS = "args" CONF_ARGS = "args"
@ -138,7 +137,7 @@ class CommandLineLoginFlow(LoginFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}

View File

@ -12,11 +12,10 @@ import voluptuous as vol
from homeassistant.const import CONF_ID from homeassistant.const import CONF_ID
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers.storage import Store from homeassistant.helpers.storage import Store
from ..models import Credentials, UserMeta from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
STORAGE_VERSION = 1 STORAGE_VERSION = 1
@ -321,7 +320,7 @@ class HassLoginFlow(LoginFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}

View File

@ -8,10 +8,9 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
from ..models import Credentials, UserMeta from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
USER_SCHEMA = vol.Schema( USER_SCHEMA = vol.Schema(
@ -98,7 +97,7 @@ class ExampleLoginFlow(LoginFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = None errors = None

View File

@ -11,12 +11,11 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.core import async_get_hass, callback from homeassistant.core import async_get_hass, callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue
from ..models import Credentials, UserMeta from ..models import AuthFlowResult, Credentials, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
AUTH_PROVIDER_TYPE = "legacy_api_password" AUTH_PROVIDER_TYPE = "legacy_api_password"
@ -101,7 +100,7 @@ class LegacyLoginFlow(LoginFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
errors = {} errors = {}

View File

@ -19,13 +19,12 @@ from typing import Any, cast
import voluptuous as vol import voluptuous as vol
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult
from homeassistant.exceptions import HomeAssistantError from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.network import is_cloud_connection from homeassistant.helpers.network import is_cloud_connection
from .. import InvalidAuthError from .. import InvalidAuthError
from ..models import Credentials, RefreshToken, UserMeta from ..models import AuthFlowResult, Credentials, RefreshToken, UserMeta
from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow from . import AUTH_PROVIDER_SCHEMA, AUTH_PROVIDERS, AuthProvider, LoginFlow
IPAddress = IPv4Address | IPv6Address IPAddress = IPv4Address | IPv6Address
@ -226,7 +225,7 @@ class TrustedNetworksLoginFlow(LoginFlow):
async def async_step_init( async def async_step_init(
self, user_input: dict[str, str] | None = None self, user_input: dict[str, str] | None = None
) -> FlowResult: ) -> AuthFlowResult:
"""Handle the step of the form.""" """Handle the step of the form."""
try: try:
cast( cast(

View File

@ -79,7 +79,7 @@ import voluptuous_serialize
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import Credentials from homeassistant.auth.models import AuthFlowResult, Credentials
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.auth import async_user_not_allowed_do_auth
from homeassistant.components.http.ban import ( from homeassistant.components.http.ban import (
@ -197,8 +197,8 @@ class AuthProvidersView(HomeAssistantView):
def _prepare_result_json( def _prepare_result_json(
result: data_entry_flow.FlowResult, result: AuthFlowResult,
) -> data_entry_flow.FlowResult: ) -> AuthFlowResult:
"""Convert result to JSON.""" """Convert result to JSON."""
if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] == data_entry_flow.FlowResultType.CREATE_ENTRY:
data = result.copy() data = result.copy()
@ -237,7 +237,7 @@ class LoginFlowBaseView(HomeAssistantView):
self, self,
request: web.Request, request: web.Request,
client_id: str, client_id: str,
result: data_entry_flow.FlowResult, result: AuthFlowResult,
) -> web.Response: ) -> web.Response:
"""Convert the flow result to a response.""" """Convert the flow result to a response."""
if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY: if result["type"] != data_entry_flow.FlowResultType.CREATE_ENTRY:
@ -297,7 +297,9 @@ class LoginFlowIndexView(LoginFlowBaseView):
vol.Schema( vol.Schema(
{ {
vol.Required("client_id"): str, vol.Required("client_id"): str,
vol.Required("handler"): vol.Any(str, list), vol.Required("handler"): vol.All(
[vol.Any(str, None)], vol.Length(2, 2), vol.Coerce(tuple)
),
vol.Required("redirect_uri"): str, vol.Required("redirect_uri"): str,
vol.Optional("type", default="authorize"): str, vol.Optional("type", default="authorize"): str,
} }
@ -312,15 +314,11 @@ class LoginFlowIndexView(LoginFlowBaseView):
if not indieauth.verify_client_id(client_id): if not indieauth.verify_client_id(client_id):
return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST) return self.json_message("Invalid client id", HTTPStatus.BAD_REQUEST)
handler: tuple[str, ...] | str handler: tuple[str, str] = tuple(data["handler"])
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
else:
handler = data["handler"]
try: try:
result = await self._flow_mgr.async_init( result = await self._flow_mgr.async_init(
handler, # type: ignore[arg-type] handler,
context={ context={
"ip_address": ip_address(request.remote), # type: ignore[arg-type] "ip_address": ip_address(request.remote), # type: ignore[arg-type]
"credential_only": data.get("type") == "link_user", "credential_only": data.get("type") == "link_user",

View File

@ -182,7 +182,7 @@ class BaseZwaveJSFlow(ConfigEntryBaseFlow, ABC):
@property @property
@abstractmethod @abstractmethod
def flow_manager(self) -> FlowManager[ConfigFlowResult]: def flow_manager(self) -> FlowManager[ConfigFlowResult, str]:
"""Return the flow manager of the flow.""" """Return the flow manager of the flow."""
async def async_step_install_addon( async def async_step_install_addon(

View File

@ -1045,7 +1045,7 @@ class FlowCancelledError(Exception):
"""Error to indicate that a flow has been cancelled.""" """Error to indicate that a flow has been cancelled."""
class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
"""Manage all the config entry flows that are in progress.""" """Manage all the config entry flows that are in progress."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -1171,7 +1171,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish a config flow and add an entry.""" """Finish a config flow and add an entry."""
@ -1293,7 +1293,7 @@ class ConfigEntriesFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_post_init( async def async_post_init(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> None: ) -> None:
"""After a flow is initialised trigger new flow notifications.""" """After a flow is initialised trigger new flow notifications."""
@ -1940,7 +1940,7 @@ def _async_abort_entries_match(
raise data_entry_flow.AbortFlow("already_configured") raise data_entry_flow.AbortFlow("already_configured")
class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult]): class ConfigEntryBaseFlow(data_entry_flow.FlowHandler[ConfigFlowResult, str]):
"""Base class for config and option flows.""" """Base class for config and option flows."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2292,7 +2292,7 @@ class ConfigFlow(ConfigEntryBaseFlow):
return self.async_abort(reason=reason) return self.async_abort(reason=reason)
class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]): class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult, str]):
"""Flow to set options for a configuration entry.""" """Flow to set options for a configuration entry."""
_flow_result = ConfigFlowResult _flow_result = ConfigFlowResult
@ -2322,7 +2322,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
async def async_finish_flow( async def async_finish_flow(
self, self,
flow: data_entry_flow.FlowHandler[ConfigFlowResult], flow: data_entry_flow.FlowHandler[ConfigFlowResult, str],
result: ConfigFlowResult, result: ConfigFlowResult,
) -> ConfigFlowResult: ) -> ConfigFlowResult:
"""Finish an options flow and update options for configuration entry. """Finish an options flow and update options for configuration entry.
@ -2344,7 +2344,7 @@ class OptionsFlowManager(data_entry_flow.FlowManager[ConfigFlowResult]):
return result return result
async def _async_setup_preview( async def _async_setup_preview(
self, flow: data_entry_flow.FlowHandler[ConfigFlowResult] self, flow: data_entry_flow.FlowHandler[ConfigFlowResult, str]
) -> None: ) -> None:
"""Set up preview for an option flow handler.""" """Set up preview for an option flow handler."""
entry = self._async_get_config_entry(flow.handler) entry = self._async_get_config_entry(flow.handler)

View File

@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = {
} }
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult") _FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
_HandlerT = TypeVar("_HandlerT", default=str)
@dataclass(slots=True) @dataclass(slots=True)
@ -138,7 +139,7 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders self.description_placeholders = description_placeholders
class FlowResult(TypedDict, total=False): class FlowResult(TypedDict, Generic[_HandlerT], total=False):
"""Typed result dict.""" """Typed result dict."""
context: dict[str, Any] context: dict[str, Any]
@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False):
errors: dict[str, str] | None errors: dict[str, str] | None
extra: str extra: str
flow_id: Required[str] flow_id: Required[str]
handler: Required[str] handler: Required[_HandlerT]
last_step: bool | None last_step: bool | None
menu_options: list[str] | dict[str, str] menu_options: list[str] | dict[str, str]
options: Mapping[str, Any] options: Mapping[str, Any]
@ -189,7 +190,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT]): class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress.""" """Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment] _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
) -> None: ) -> None:
"""Initialize the flow manager.""" """Initialize the flow manager."""
self.hass = hass self.hass = hass
self._preview: set[str] = set() self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT]] = {} self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {} self._handler_progress_index: dict[
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {} _HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
self._init_data_process_index: dict[
type, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
@abc.abstractmethod @abc.abstractmethod
async def async_create_flow( async def async_create_flow(
self, self,
handler_key: str, handler_key: _HandlerT,
*, *,
context: dict[str, Any] | None = None, context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None, data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT]: ) -> FlowHandler[_FlowResultT, _HandlerT]:
"""Create a flow for specified handler. """Create a flow for specified handler.
Handler key is the domain of the component that we want to set up. Handler key is the domain of the component that we want to set up.
@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod @abc.abstractmethod
async def async_finish_flow( async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> _FlowResultT: ) -> _FlowResultT:
"""Finish a data entry flow.""" """Finish a data entry flow."""
async def async_post_init( async def async_post_init(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> None: ) -> None:
"""Entry has finished executing its first step asynchronously.""" """Entry has finished executing its first step asynchronously."""
@callback @callback
def async_has_matching_flow( def async_has_matching_flow(
self, handler: str, match_context: dict[str, Any], data: Any self, handler: _HandlerT, match_context: dict[str, Any], data: Any
) -> bool: ) -> bool:
"""Check if an existing matching flow is in progress. """Check if an existing matching flow is in progress.
@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback @callback
def async_progress_by_handler( def async_progress_by_handler(
self, self,
handler: str, handler: _HandlerT,
include_uninitialized: bool = False, include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None, match_context: dict[str, Any] | None = None,
) -> list[_FlowResultT]: ) -> list[_FlowResultT]:
@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback @callback
def _async_progress_by_handler( def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT]]: ) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler. """Return the flows in progress by handler.
If match_context is specified, only return flows with a context that If match_context is specified, only return flows with a context that
@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
] ]
async def async_init( async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None self,
handler: _HandlerT,
*,
context: dict[str, Any] | None = None,
data: Any = None,
) -> _FlowResultT: ) -> _FlowResultT:
"""Start a data entry flow.""" """Start a data entry flow."""
if context is None: if context is None:
@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id) self._async_remove_flow_progress(flow_id)
@callback @callback
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None: def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Add a flow to in progress.""" """Add a flow to in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow) self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback @callback
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None: def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Remove a flow from in progress.""" """Remove a flow from in progress."""
if flow.init_data is not None: if flow.init_data is not None:
init_data_type = type(flow.init_data) init_data_type = type(flow.init_data)
@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step( async def _async_handle_step(
self, self,
flow: FlowHandler[_FlowResultT], flow: FlowHandler[_FlowResultT, _HandlerT],
step_id: str, step_id: str,
user_input: dict | BaseServiceInfo | None, user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT: ) -> _FlowResultT:
@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return result return result
def _raise_if_step_does_not_exist( def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT], step_id: str self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
) -> None: ) -> None:
"""Raise if the step does not exist.""" """Raise if the step does not exist."""
method = f"async_step_{step_id}" method = f"async_step_{step_id}"
@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}" f"Handler {self.__class__.__name__} doesn't support step {step_id}"
) )
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None: async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Set up preview for a flow handler.""" """Set up preview for a flow handler."""
if flow.handler not in self._preview: if flow.handler not in self._preview:
self._preview.add(flow.handler) self._preview.add(flow.handler)
@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback @callback
def _async_flow_handler_to_flow_result( def _async_flow_handler_to_flow_result(
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
include_uninitialized: bool,
) -> list[_FlowResultT]: ) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized.""" """Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = [] results = []
@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return results return results
class FlowHandler(Generic[_FlowResultT]): class FlowHandler(Generic[_FlowResultT, _HandlerT]):
"""Handle a data entry flow.""" """Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment] _flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]):
# and removes the need for constant None checks or asserts. # and removes the need for constant None checks or asserts.
flow_id: str = None # type: ignore[assignment] flow_id: str = None # type: ignore[assignment]
hass: HomeAssistant = None # type: ignore[assignment] hass: HomeAssistant = None # type: ignore[assignment]
handler: str = None # type: ignore[assignment] handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value. # Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment] context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]

View File

@ -17,7 +17,7 @@ from . import config_validation as cv
_FlowManagerT = TypeVar( _FlowManagerT = TypeVar(
"_FlowManagerT", "_FlowManagerT",
bound=data_entry_flow.FlowManager[Any], bound="data_entry_flow.FlowManager[Any]",
default=data_entry_flow.FlowManager, default=data_entry_flow.FlowManager,
) )
@ -61,7 +61,7 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
@RequestDataValidator( @RequestDataValidator(
vol.Schema( vol.Schema(
{ {
vol.Required("handler"): vol.Any(str, list), vol.Required("handler"): str,
vol.Optional("show_advanced_options", default=False): cv.boolean, vol.Optional("show_advanced_options", default=False): cv.boolean,
}, },
extra=vol.ALLOW_EXTRA, extra=vol.ALLOW_EXTRA,
@ -79,14 +79,9 @@ class FlowManagerIndexView(_BaseFlowManagerView[_FlowManagerT]):
self, request: web.Request, data: dict[str, Any] self, request: web.Request, data: dict[str, Any]
) -> web.Response: ) -> web.Response:
"""Handle a POST request.""" """Handle a POST request."""
if isinstance(data["handler"], list):
handler = tuple(data["handler"])
else:
handler = data["handler"]
try: try:
result = await self._flow_mgr.async_init( result = await self._flow_mgr.async_init(
handler, # type: ignore[arg-type] data["handler"],
context=self.get_context(data), context=self.get_context(data),
) )
except data_entry_flow.UnknownHandler: except data_entry_flow.UnknownHandler: