Make FlowResult a generic type (#111952)

This commit is contained in:
Erik Montnemery
2024-03-07 12:41:14 +01:00
committed by GitHub
parent 008e025d5c
commit 82efb3d35b
13 changed files with 95 additions and 80 deletions

View File

@@ -85,7 +85,8 @@ STEP_ID_OPTIONAL_STEPS = {
}
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult", default="FlowResult")
_FlowResultT = TypeVar("_FlowResultT", bound="FlowResult[Any]", default="FlowResult")
_HandlerT = TypeVar("_HandlerT", default=str)
@dataclass(slots=True)
@@ -138,7 +139,7 @@ class AbortFlow(FlowError):
self.description_placeholders = description_placeholders
class FlowResult(TypedDict, total=False):
class FlowResult(TypedDict, Generic[_HandlerT], total=False):
"""Typed result dict."""
context: dict[str, Any]
@@ -149,7 +150,7 @@ class FlowResult(TypedDict, total=False):
errors: dict[str, str] | None
extra: str
flow_id: Required[str]
handler: Required[str]
handler: Required[_HandlerT]
last_step: bool | None
menu_options: list[str] | dict[str, str]
options: Mapping[str, Any]
@@ -189,7 +190,7 @@ def _map_error_to_schema_errors(
schema_errors[path_part_str] = error.error_message
class FlowManager(abc.ABC, Generic[_FlowResultT]):
class FlowManager(abc.ABC, Generic[_FlowResultT, _HandlerT]):
"""Manage all the flows that are in progress."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@@ -200,19 +201,23 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
) -> None:
"""Initialize the flow manager."""
self.hass = hass
self._preview: set[str] = set()
self._progress: dict[str, FlowHandler[_FlowResultT]] = {}
self._handler_progress_index: dict[str, set[FlowHandler[_FlowResultT]]] = {}
self._init_data_process_index: dict[type, set[FlowHandler[_FlowResultT]]] = {}
self._preview: set[_HandlerT] = set()
self._progress: dict[str, FlowHandler[_FlowResultT, _HandlerT]] = {}
self._handler_progress_index: dict[
_HandlerT, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
self._init_data_process_index: dict[
type, set[FlowHandler[_FlowResultT, _HandlerT]]
] = {}
@abc.abstractmethod
async def async_create_flow(
self,
handler_key: str,
handler_key: _HandlerT,
*,
context: dict[str, Any] | None = None,
data: dict[str, Any] | None = None,
) -> FlowHandler[_FlowResultT]:
) -> FlowHandler[_FlowResultT, _HandlerT]:
"""Create a flow for specified handler.
Handler key is the domain of the component that we want to set up.
@@ -220,18 +225,18 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@abc.abstractmethod
async def async_finish_flow(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> _FlowResultT:
"""Finish a data entry flow."""
async def async_post_init(
self, flow: FlowHandler[_FlowResultT], result: _FlowResultT
self, flow: FlowHandler[_FlowResultT, _HandlerT], result: _FlowResultT
) -> None:
"""Entry has finished executing its first step asynchronously."""
@callback
def async_has_matching_flow(
self, handler: str, match_context: dict[str, Any], data: Any
self, handler: _HandlerT, match_context: dict[str, Any], data: Any
) -> bool:
"""Check if an existing matching flow is in progress.
@@ -265,7 +270,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def async_progress_by_handler(
self,
handler: str,
handler: _HandlerT,
include_uninitialized: bool = False,
match_context: dict[str, Any] | None = None,
) -> list[_FlowResultT]:
@@ -298,8 +303,8 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_progress_by_handler(
self, handler: str, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT]]:
self, handler: _HandlerT, match_context: dict[str, Any] | None
) -> list[FlowHandler[_FlowResultT, _HandlerT]]:
"""Return the flows in progress by handler.
If match_context is specified, only return flows with a context that
@@ -315,7 +320,11 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
]
async def async_init(
self, handler: str, *, context: dict[str, Any] | None = None, data: Any = None
self,
handler: _HandlerT,
*,
context: dict[str, Any] | None = None,
data: Any = None,
) -> _FlowResultT:
"""Start a data entry flow."""
if context is None:
@@ -445,7 +454,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._async_remove_flow_progress(flow_id)
@callback
def _async_add_flow_progress(self, flow: FlowHandler[_FlowResultT]) -> None:
def _async_add_flow_progress(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Add a flow to in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@@ -454,7 +465,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
self._handler_progress_index.setdefault(flow.handler, set()).add(flow)
@callback
def _async_remove_flow_from_index(self, flow: FlowHandler[_FlowResultT]) -> None:
def _async_remove_flow_from_index(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Remove a flow from in progress."""
if flow.init_data is not None:
init_data_type = type(flow.init_data)
@@ -480,7 +493,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
async def _async_handle_step(
self,
flow: FlowHandler[_FlowResultT],
flow: FlowHandler[_FlowResultT, _HandlerT],
step_id: str,
user_input: dict | BaseServiceInfo | None,
) -> _FlowResultT:
@@ -557,7 +570,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return result
def _raise_if_step_does_not_exist(
self, flow: FlowHandler[_FlowResultT], step_id: str
self, flow: FlowHandler[_FlowResultT, _HandlerT], step_id: str
) -> None:
"""Raise if the step does not exist."""
method = f"async_step_{step_id}"
@@ -568,7 +581,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
f"Handler {self.__class__.__name__} doesn't support step {step_id}"
)
async def _async_setup_preview(self, flow: FlowHandler[_FlowResultT]) -> None:
async def _async_setup_preview(
self, flow: FlowHandler[_FlowResultT, _HandlerT]
) -> None:
"""Set up preview for a flow handler."""
if flow.handler not in self._preview:
self._preview.add(flow.handler)
@@ -576,7 +591,9 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
@callback
def _async_flow_handler_to_flow_result(
self, flows: Iterable[FlowHandler[_FlowResultT]], include_uninitialized: bool
self,
flows: Iterable[FlowHandler[_FlowResultT, _HandlerT]],
include_uninitialized: bool,
) -> list[_FlowResultT]:
"""Convert a list of FlowHandler to a partial FlowResult that can be serialized."""
results = []
@@ -594,7 +611,7 @@ class FlowManager(abc.ABC, Generic[_FlowResultT]):
return results
class FlowHandler(Generic[_FlowResultT]):
class FlowHandler(Generic[_FlowResultT, _HandlerT]):
"""Handle a data entry flow."""
_flow_result: Callable[..., _FlowResultT] = FlowResult # type: ignore[assignment]
@@ -606,7 +623,7 @@ class FlowHandler(Generic[_FlowResultT]):
# and removes the need for constant None checks or asserts.
flow_id: str = None # type: ignore[assignment]
hass: HomeAssistant = None # type: ignore[assignment]
handler: str = None # type: ignore[assignment]
handler: _HandlerT = None # type: ignore[assignment]
# Ensure the attribute has a subscriptable, but immutable, default value.
context: dict[str, Any] = MappingProxyType({}) # type: ignore[assignment]