Use ConfigFlow.has_matching_flow to deduplicate samsungtv flows (#127235)

This commit is contained in:
Erik Montnemery 2024-10-01 17:56:38 +02:00 committed by GitHub
parent 1c11229510
commit 4060705d87
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 81 additions and 6 deletions

View File

@ -5,7 +5,7 @@ from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from functools import partial from functools import partial
import socket import socket
from typing import Any from typing import Any, Self
from urllib.parse import urlparse from urllib.parse import urlparse
import getmac import getmac
@ -425,10 +425,12 @@ class SamsungTVConfigFlow(ConfigFlow, domain=DOMAIN):
@callback @callback
def _async_abort_if_host_already_in_progress(self) -> None: def _async_abort_if_host_already_in_progress(self) -> None:
self.context[CONF_HOST] = self._host if self.hass.config_entries.flow.async_has_matching_flow(self):
for progress in self._async_in_progress(): raise AbortFlow("already_in_progress")
if progress.get("context", {}).get(CONF_HOST) == self._host:
raise AbortFlow("already_in_progress") def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return other_flow._host == self._host # noqa: SLF001
@callback @callback
def _abort_if_manufacturer_is_not_samsung(self) -> None: def _abort_if_manufacturer_is_not_samsung(self) -> None:

View File

@ -22,6 +22,7 @@ from websockets.exceptions import (
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import dhcp, ssdp, zeroconf from homeassistant.components import dhcp, ssdp, zeroconf
from homeassistant.components.samsungtv.config_flow import SamsungTVConfigFlow
from homeassistant.components.samsungtv.const import ( from homeassistant.components.samsungtv.const import (
CONF_MANUFACTURER, CONF_MANUFACTURER,
CONF_SESSION_ID, CONF_SESSION_ID,
@ -56,7 +57,7 @@ from homeassistant.const import (
CONF_TOKEN, CONF_TOKEN,
) )
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import BaseServiceInfo, FlowResultType
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .const import ( from .const import (
@ -982,6 +983,78 @@ async def test_dhcp_wired(hass: HomeAssistant, rest_api: Mock) -> None:
assert result["result"].unique_id == "be9554b9-c9fb-41f4-8920-22da015376a4" assert result["result"].unique_id == "be9554b9-c9fb-41f4-8920-22da015376a4"
@pytest.mark.usefixtures("remotews", "rest_api_non_ssl_only", "remoteencws_failing")
@pytest.mark.parametrize(
("source1", "data1", "source2", "data2", "is_matching_result"),
[
(
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
True,
),
(
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
False,
),
(
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
config_entries.SOURCE_DHCP,
MOCK_DHCP_DATA,
False,
),
(
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
config_entries.SOURCE_ZEROCONF,
MOCK_ZEROCONF_DATA,
True,
),
],
)
async def test_dhcp_zeroconf_already_in_progress(
hass: HomeAssistant,
source1: str,
data1: BaseServiceInfo,
source2: str,
data2: BaseServiceInfo,
is_matching_result: bool,
) -> None:
"""Test starting a flow from dhcp or zeroconf when already in progress."""
# confirm to add the entry
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": source1}, data=data1
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "confirm"
real_is_matching = SamsungTVConfigFlow.is_matching
return_values = []
def is_matching(self, other_flow) -> bool:
return_values.append(real_is_matching(self, other_flow))
return return_values[-1]
with patch.object(
SamsungTVConfigFlow, "is_matching", wraps=is_matching, autospec=True
):
# confirm to add the entry
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": source2}, data=data2
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == RESULT_ALREADY_IN_PROGRESS
# Ensure the is_matching method returned the expected value
assert return_values == [is_matching_result]
@pytest.mark.usefixtures("remotews", "rest_api", "remoteencws_failing") @pytest.mark.usefixtures("remotews", "rest_api", "remoteencws_failing")
async def test_zeroconf(hass: HomeAssistant) -> None: async def test_zeroconf(hass: HomeAssistant) -> None:
"""Test starting a flow from zeroconf.""" """Test starting a flow from zeroconf."""