Use ConfigFlow.has_matching_flow to deduplicate yeelight flows (#127165)

This commit is contained in:
Erik Montnemery 2024-10-02 08:25:46 +02:00 committed by GitHub
parent e3e68dad36
commit 375d47ee3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 25 additions and 7 deletions

View File

@ -3,7 +3,7 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from typing import Any from typing import Any, Self
from urllib.parse import urlparse from urllib.parse import urlparse
import voluptuous as vol import voluptuous as vol
@ -53,7 +53,7 @@ class YeelightConfigFlow(ConfigFlow, domain=DOMAIN):
VERSION = 1 VERSION = 1
_discovered_ip: str _discovered_ip: str = ""
_discovered_model: str _discovered_model: str
@staticmethod @staticmethod
@ -119,9 +119,7 @@ class YeelightConfigFlow(ConfigFlow, domain=DOMAIN):
async def _async_handle_discovery(self) -> ConfigFlowResult: async def _async_handle_discovery(self) -> ConfigFlowResult:
"""Handle any discovery.""" """Handle any discovery."""
self.context[CONF_HOST] = self._discovered_ip if self.hass.config_entries.flow.async_has_matching_flow(self):
for progress in self._async_in_progress():
if progress.get("context", {}).get(CONF_HOST) == self._discovered_ip:
return self.async_abort(reason="already_in_progress") return self.async_abort(reason="already_in_progress")
self._async_abort_entries_match({CONF_HOST: self._discovered_ip}) self._async_abort_entries_match({CONF_HOST: self._discovered_ip})
@ -140,6 +138,10 @@ class YeelightConfigFlow(ConfigFlow, domain=DOMAIN):
) )
return await self.async_step_discovery_confirm() return await self.async_step_discovery_confirm()
def is_matching(self, other_flow: Self) -> bool:
"""Return True if other_flow is matching this flow."""
return other_flow._discovered_ip == self._discovered_ip # noqa: SLF001
async def async_step_discovery_confirm( async def async_step_discovery_confirm(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult: ) -> ConfigFlowResult:

View File

@ -7,7 +7,11 @@ import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components import dhcp, ssdp, zeroconf from homeassistant.components import dhcp, ssdp, zeroconf
from homeassistant.components.yeelight.config_flow import MODEL_UNKNOWN, CannotConnect from homeassistant.components.yeelight.config_flow import (
MODEL_UNKNOWN,
CannotConnect,
YeelightConfigFlow,
)
from homeassistant.components.yeelight.const import ( from homeassistant.components.yeelight.const import (
CONF_DETECTED_MODEL, CONF_DETECTED_MODEL,
CONF_MODE_MUSIC, CONF_MODE_MUSIC,
@ -503,10 +507,20 @@ async def test_discovered_by_homekit_and_dhcp(hass: HomeAssistant) -> None:
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] is None assert result["errors"] is None
real_is_matching = YeelightConfigFlow.is_matching
return_values = []
def is_matching(self, other_flow) -> bool:
return_values.append(real_is_matching(self, other_flow))
return return_values[-1]
with ( with (
_patch_discovery(), _patch_discovery(),
_patch_discovery_interval(), _patch_discovery_interval(),
patch(f"{MODULE_CONFIG_FLOW}.AsyncBulb", return_value=mocked_bulb), patch(f"{MODULE_CONFIG_FLOW}.AsyncBulb", return_value=mocked_bulb),
patch.object(
YeelightConfigFlow, "is_matching", wraps=is_matching, autospec=True
),
): ):
result2 = await hass.config_entries.flow.async_init( result2 = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
@ -518,6 +532,8 @@ async def test_discovered_by_homekit_and_dhcp(hass: HomeAssistant) -> None:
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] is FlowResultType.ABORT assert result2["type"] is FlowResultType.ABORT
assert result2["reason"] == "already_in_progress" assert result2["reason"] == "already_in_progress"
# Ensure the is_matching method returned True
assert return_values == [True]
with ( with (
_patch_discovery(), _patch_discovery(),