mirror of
https://github.com/home-assistant/core.git
synced 2025-11-08 18:39:30 +00:00
Make FlowResult a generic type (#111952)
This commit is contained in:
@@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = {
|
||||
}
|
||||
|
||||
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
|
||||
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
|
||||
_HandlerT = TypeVar("_HandlerT", default=str)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
@@ -138,7 +139,7 @@ class AbortFlow(FlowError):
|
||||
self.description_placeholders = description_placeholders
|
||||
|
||||
|
||||
class FlowResult(TypedDict, total=False):
|
||||
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
|
||||
"""Typed result dict."""
|
||||
|
||||
context: dict[str, Any]
|
||||
@@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False):
|
||||
errors: dict[str, str] | None
|
||||
extra: str
|
||||
flow_id: Required[str]
|
||||
handler: Required[str]
|
||||
handler: Required[_HandlerT]
|
||||
last_step: bool | None
|
||||
menu_options: list[str] | dict[str, str]
|
||||
options: Mapping[str, Any]
|
||||
@@ -189,7 +190,7 @@ def _map_error_to_schema_errors(
|
||||
schema_errors[path_part_str] = error.error_message
|
||||
|
||||
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
|
||||
"""Manage all the flows that are in progress."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
@@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
) -> None:
|
||||
"""Initialize the flow manager."""
|
||||
self.hass = hass
|
||||
self._preview: set[str] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
|
||||
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
|
||||
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
|
||||
self._preview: set[_HandlerT] = set()
|
||||
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
|
||||
self._handler_progress_index: dict[
|
||||
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
] = {}
|
||||
self._init_data_process_index: dict[
|
||||
type, set[FlowHandler[_FlowResultT, _HandlerT]]
|
||||
] = {}
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_create_flow(
|
||||
self,
|
||||
handler_key: str,
|
||||
handler_key: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
) -> FlowHandler[_FlowResultT]:
|
||||
) -> FlowHandler[_FlowResultT, _HandlerT]:
|
||||
"""Create a flow for specified handler.
|
||||
|
||||
Handler key is the domain of the component that we want to set up.
|
||||
@@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
|
||||
@abc.abstractmethod
|
||||
async def async_finish_flow(
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
) -> _FlowResultT:
|
||||
"""Finish a data entry flow."""
|
||||
|
||||
async def async_post_init(
|
||||
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
|
||||
) -> None:
|
||||
"""Entry has finished executing its first step asynchronously."""
|
||||
|
||||
@callback
|
||||
def async_has_matching_flow(
|
||||
self, handler: str, match_context: dict[str, Any], data: Any
|
||||
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
|
||||
) -> bool:
|
||||
"""Check if an existing matching flow is in progress.
|
||||
|
||||
@@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
@callback
|
||||
def async_progress_by_handler(
|
||||
self,
|
||||
handler: str,
|
||||
handler: _HandlerT,
|
||||
include_uninitialized: bool = False,
|
||||
match_context: dict[str, Any] | None = None,
|
||||
) -> list[_FlowResultT]:
|
||||
@@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
|
||||
@callback
|
||||
def _async_progress_by_handler(
|
||||
self, handler: str, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT]]:
|
||||
self, handler: _HandlerT, match_context: dict[str, Any] | None
|
||||
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
|
||||
"""Return the flows in progress by handler.
|
||||
|
||||
If match_context is specified, only return flows with a context that
|
||||
@@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
]
|
||||
|
||||
async def async_init(
|
||||
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
|
||||
self,
|
||||
handler: _HandlerT,
|
||||
*,
|
||||
context: dict[str, Any] | None = None,
|
||||
data: Any = None,
|
||||
) -> _FlowResultT:
|
||||
"""Start a data entry flow."""
|
||||
if context is None:
|
||||
@@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
self._async_remove_flow_progress(flow_id)
|
||||
|
||||
@callback
|
||||
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
def _async_add_flow_progress(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Add a flow to in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
@@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
|
||||
|
||||
@callback
|
||||
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
def _async_remove_flow_from_index(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Remove a flow from in progress."""
|
||||
if flow.init_data is not None:
|
||||
init_data_type = type(flow.init_data)
|
||||
@@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
|
||||
async def _async_handle_step(
|
||||
self,
|
||||
flow: FlowHandler[_FlowResultT],
|
||||
flow: FlowHandler[_FlowResultT, _HandlerT],
|
||||
step_id: str,
|
||||
user_input: dict | BaseServiceInfo | None,
|
||||
) -> _FlowResultT:
|
||||
@@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
return result
|
||||
|
||||
def _raise_if_step_does_not_exist(
|
||||
self, flow: FlowHandler[_FlowResultT], step_id: str
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
|
||||
) -> None:
|
||||
"""Raise if the step does not exist."""
|
||||
method = f"async_step_{step_id}"
|
||||
@@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
|
||||
)
|
||||
|
||||
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
|
||||
async def _async_setup_preview(
|
||||
self, flow: FlowHandler[_FlowResultT, _HandlerT]
|
||||
) -> None:
|
||||
"""Set up preview for a flow handler."""
|
||||
if flow.handler not in self._preview:
|
||||
self._preview.add(flow.handler)
|
||||
@@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
|
||||
@callback
|
||||
def _async_flow_handler_to_flow_result(
|
||||
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
|
||||
self,
|
||||
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
|
||||
include_uninitialized: bool,
|
||||
) -> list[_FlowResultT]:
|
||||
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
|
||||
results = []
|
||||
@@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
|
||||
return results
|
||||
|
||||
|
||||
class FlowHandler(Generic[_FlowResultT]):
|
||||
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
|
||||
"""Handle a data entry flow."""
|
||||
|
||||
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
|
||||
@@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]):
|
||||
# and removes the need for constant None checks or asserts.
|
||||
flow_id: str = None # type: ignore[assignment]
|
||||
hass: HomeAssistant = None # type: ignore[assignment]
|
||||
handler: str = None # type: ignore[assignment]
|
||||
handler: _HandlerT = None # type: ignore[assignment]
|
||||
# Ensure the attribute has a subscriptable, but immutable, default value.
|
||||
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user