diff --git a/homeassistant/components/demo/repairs.py b/homeassistant/components/demo/repairs.py index 2d7c8b4cbcc..cddc937a71a 100644 --- a/homeassistant/components/demo/repairs.py +++ b/homeassistant/components/demo/repairs.py @@ -6,6 +6,7 @@ import voluptuous as vol from homeassistant import data_entry_flow from homeassistant.components.repairs import ConfirmRepairFlow, RepairsFlow +from homeassistant.core import HomeAssistant class DemoFixFlow(RepairsFlow): @@ -28,7 +29,11 @@ class DemoFixFlow(RepairsFlow): return self.async_show_form(step_id="confirm", data_schema=vol.Schema({})) -async def async_create_fix_flow(hass, issue_id): +async def async_create_fix_flow( + hass: HomeAssistant, + issue_id: str, + data: dict[str, str | int | float | None] | None, +) -> RepairsFlow: """Create flow.""" if issue_id == "bad_psu": # The bad_psu issue doesn't have its own flow diff --git a/homeassistant/components/flunearyou/repairs.py b/homeassistant/components/flunearyou/repairs.py index f48085ba623..df81a1ae576 100644 --- a/homeassistant/components/flunearyou/repairs.py +++ b/homeassistant/components/flunearyou/repairs.py @@ -36,7 +36,9 @@ class FluNearYouFixFlow(RepairsFlow): async def async_create_fix_flow( - hass: HomeAssistant, issue_id: str -) -> FluNearYouFixFlow: + hass: HomeAssistant, + issue_id: str, + data: dict[str, str | int | float | None] | None, +) -> RepairsFlow: """Create flow.""" return FluNearYouFixFlow() diff --git a/homeassistant/components/repairs/issue_handler.py b/homeassistant/components/repairs/issue_handler.py index 23f37754ffe..5695e99998b 100644 --- a/homeassistant/components/repairs/issue_handler.py +++ b/homeassistant/components/repairs/issue_handler.py @@ -64,10 +64,14 @@ class RepairsFlowManager(data_entry_flow.FlowManager): platforms: dict[str, RepairsProtocol] = self.hass.data[DOMAIN]["platforms"] if handler_key not in platforms: - return ConfirmRepairFlow() - platform = platforms[handler_key] + flow: RepairsFlow = ConfirmRepairFlow() + else: + platform = platforms[handler_key] + flow = await platform.async_create_fix_flow(self.hass, issue_id, issue.data) - return await platform.async_create_fix_flow(self.hass, issue_id) + flow.issue_id = issue_id + flow.data = issue.data + return flow async def async_finish_flow( self, flow: data_entry_flow.FlowHandler, result: data_entry_flow.FlowResult @@ -109,6 +113,7 @@ def async_create_issue( *, issue_domain: str | None = None, breaks_in_ha_version: str | None = None, + data: dict[str, str | int | float | None] | None = None, is_fixable: bool, is_persistent: bool = False, learn_more_url: str | None = None, @@ -131,6 +136,7 @@ def async_create_issue( issue_id, issue_domain=issue_domain, breaks_in_ha_version=breaks_in_ha_version, + data=data, is_fixable=is_fixable, is_persistent=is_persistent, learn_more_url=learn_more_url, @@ -146,6 +152,7 @@ def create_issue( issue_id: str, *, breaks_in_ha_version: str | None = None, + data: dict[str, str | int | float | None] | None = None, is_fixable: bool, is_persistent: bool = False, learn_more_url: str | None = None, @@ -162,6 +169,7 @@ def create_issue( domain, issue_id, breaks_in_ha_version=breaks_in_ha_version, + data=data, is_fixable=is_fixable, is_persistent=is_persistent, learn_more_url=learn_more_url, diff --git a/homeassistant/components/repairs/issue_registry.py b/homeassistant/components/repairs/issue_registry.py index c7502ecf397..f9a15e0f165 100644 --- a/homeassistant/components/repairs/issue_registry.py +++ b/homeassistant/components/repairs/issue_registry.py @@ -27,6 +27,7 @@ class IssueEntry: active: bool breaks_in_ha_version: str | None created: datetime + data: dict[str, str | int | float | None] | None dismissed_version: str | None domain: str is_fixable: bool | None @@ -53,6 +54,7 @@ class IssueEntry: return { **result, "breaks_in_ha_version": self.breaks_in_ha_version, + "data": self.data, "is_fixable": self.is_fixable, "is_persistent": True, "issue_domain": self.issue_domain, @@ -106,6 +108,7 @@ class IssueRegistry: *, issue_domain: str | None = None, breaks_in_ha_version: str | None = None, + data: dict[str, str | int | float | None] | None = None, is_fixable: bool, is_persistent: bool, learn_more_url: str | None = None, @@ -120,6 +123,7 @@ class IssueRegistry: active=True, breaks_in_ha_version=breaks_in_ha_version, created=dt_util.utcnow(), + data=data, dismissed_version=None, domain=domain, is_fixable=is_fixable, @@ -142,6 +146,7 @@ class IssueRegistry: issue, active=True, breaks_in_ha_version=breaks_in_ha_version, + data=data, is_fixable=is_fixable, is_persistent=is_persistent, issue_domain=issue_domain, @@ -204,6 +209,7 @@ class IssueRegistry: active=True, breaks_in_ha_version=issue["breaks_in_ha_version"], created=created, + data=issue["data"], dismissed_version=issue["dismissed_version"], domain=issue["domain"], is_fixable=issue["is_fixable"], @@ -220,6 +226,7 @@ class IssueRegistry: active=False, breaks_in_ha_version=None, created=created, + data=None, dismissed_version=issue["dismissed_version"], domain=issue["domain"], is_fixable=None, diff --git a/homeassistant/components/repairs/models.py b/homeassistant/components/repairs/models.py index 2a6eeb15269..1022c50e1f2 100644 --- a/homeassistant/components/repairs/models.py +++ b/homeassistant/components/repairs/models.py @@ -19,11 +19,17 @@ class IssueSeverity(StrEnum): class RepairsFlow(data_entry_flow.FlowHandler): """Handle a flow for fixing an issue.""" + issue_id: str + data: dict[str, str | int | float | None] | None + class RepairsProtocol(Protocol): """Define the format of repairs platforms.""" async def async_create_fix_flow( - self, hass: HomeAssistant, issue_id: str + self, + hass: HomeAssistant, + issue_id: str, + data: dict[str, str | int | float | None] | None, ) -> RepairsFlow: """Create a flow to fix a fixable issue.""" diff --git a/homeassistant/components/repairs/websocket_api.py b/homeassistant/components/repairs/websocket_api.py index ff0ac5ba8f9..192c9f5ac66 100644 --- a/homeassistant/components/repairs/websocket_api.py +++ b/homeassistant/components/repairs/websocket_api.py @@ -64,7 +64,8 @@ def ws_list_issues( """Return a list of issues.""" def ws_dict(kv_pairs: list[tuple[Any, Any]]) -> dict[Any, Any]: - result = {k: v for k, v in kv_pairs if k not in ("active", "is_persistent")} + excluded_keys = ("active", "data", "is_persistent") + result = {k: v for k, v in kv_pairs if k not in excluded_keys} result["ignored"] = result["dismissed_version"] is not None result["created"] = result["created"].isoformat() return result diff --git a/tests/components/repairs/test_issue_registry.py b/tests/components/repairs/test_issue_registry.py index ff6c4b996da..76faafce1c7 100644 --- a/tests/components/repairs/test_issue_registry.py +++ b/tests/components/repairs/test_issue_registry.py @@ -51,6 +51,7 @@ async def test_load_issues(hass: HomeAssistant) -> None: }, { "breaks_in_ha_version": "2022.6", + "data": {"entry_id": "123"}, "domain": "test", "issue_id": "issue_4", "is_fixable": True, @@ -141,6 +142,7 @@ async def test_load_issues(hass: HomeAssistant) -> None: active=False, breaks_in_ha_version=None, created=issue1.created, + data=None, dismissed_version=issue1.dismissed_version, domain=issue1.domain, is_fixable=None, @@ -157,6 +159,7 @@ async def test_load_issues(hass: HomeAssistant) -> None: active=False, breaks_in_ha_version=None, created=issue2.created, + data=None, dismissed_version=issue2.dismissed_version, domain=issue2.domain, is_fixable=None, @@ -196,6 +199,7 @@ async def test_loading_issues_from_storage(hass: HomeAssistant, hass_storage) -> { "breaks_in_ha_version": "2022.6", "created": "2022-07-19T19:41:13.746514+00:00", + "data": {"entry_id": "123"}, "dismissed_version": None, "domain": "test", "issue_domain": "blubb", diff --git a/tests/components/repairs/test_websocket_api.py b/tests/components/repairs/test_websocket_api.py index a47a7a899ea..1cb83d81b06 100644 --- a/tests/components/repairs/test_websocket_api.py +++ b/tests/components/repairs/test_websocket_api.py @@ -37,6 +37,17 @@ DEFAULT_ISSUES = [ async def create_issues(hass, ws_client, issues=None): """Create issues.""" + + def api_issue(issue): + excluded_keys = ("data",) + return dict( + {key: issue[key] for key in issue if key not in excluded_keys}, + created=ANY, + dismissed_version=None, + ignored=False, + issue_domain=None, + ) + if issues is None: issues = DEFAULT_ISSUES @@ -46,6 +57,7 @@ async def create_issues(hass, ws_client, issues=None): issue["domain"], issue["issue_id"], breaks_in_ha_version=issue["breaks_in_ha_version"], + data=issue.get("data"), is_fixable=issue["is_fixable"], is_persistent=False, learn_more_url=issue["learn_more_url"], @@ -58,22 +70,17 @@ async def create_issues(hass, ws_client, issues=None): msg = await ws_client.receive_json() assert msg["success"] - assert msg["result"] == { - "issues": [ - dict( - issue, - created=ANY, - dismissed_version=None, - ignored=False, - issue_domain=None, - ) - for issue in issues - ] - } + assert msg["result"] == {"issues": [api_issue(issue) for issue in issues]} return issues +EXPECTED_DATA = { + "issue_1": None, + "issue_2": {"blah": "bleh"}, +} + + class MockFixFlow(RepairsFlow): """Handler for an issue fixing flow.""" @@ -82,6 +89,9 @@ class MockFixFlow(RepairsFlow): ) -> data_entry_flow.FlowResult: """Handle the first step of a fix flow.""" + assert self.issue_id in EXPECTED_DATA + assert self.data == EXPECTED_DATA[self.issue_id] + return await (self.async_step_custom_step()) async def async_step_custom_step( @@ -99,7 +109,10 @@ async def mock_repairs_integration(hass): """Mock a repairs integration.""" hass.config.components.add("fake_integration") - def async_create_fix_flow(hass, issue_id): + def async_create_fix_flow(hass, issue_id, data): + assert issue_id in EXPECTED_DATA + assert data == EXPECTED_DATA[issue_id] + return MockFixFlow() mock_platform( @@ -256,11 +269,18 @@ async def test_fix_issue( ws_client = await hass_ws_client(hass) client = await hass_client() - issues = [{**DEFAULT_ISSUES[0], "domain": domain}] + issues = [ + { + **DEFAULT_ISSUES[0], + "data": {"blah": "bleh"}, + "domain": domain, + "issue_id": "issue_2", + } + ] await create_issues(hass, ws_client, issues=issues) url = "/api/repairs/issues/fix" - resp = await client.post(url, json={"handler": domain, "issue_id": "issue_1"}) + resp = await client.post(url, json={"handler": domain, "issue_id": "issue_2"}) assert resp.status == HTTPStatus.OK data = await resp.json()