diff --git a/homeassistant/components/samsungtv/config_flow.py b/homeassistant/components/samsungtv/config_flow.py index 4fc24c5cc3e..76128a1f1dd 100644 --- a/homeassistant/components/samsungtv/config_flow.py +++ b/homeassistant/components/samsungtv/config_flow.py @@ -51,6 +51,11 @@ def _strip_uuid(udn): return udn[5:] if udn.startswith("uuid:") else udn +def _entry_is_complete(entry): + """Return True if the config entry information is complete.""" + return bool(entry.unique_id and entry.data.get(CONF_MAC)) + + class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a Samsung TV config flow.""" @@ -93,12 +98,19 @@ class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if not await self._async_get_and_check_device_info(): raise data_entry_flow.AbortFlow(RESULT_NOT_SUPPORTED) await self._async_set_unique_id_from_udn(raise_on_progress) + self._async_update_and_abort_for_matching_unique_id() async def _async_set_unique_id_from_udn(self, raise_on_progress=True): """Set the unique id from the udn.""" assert self._host is not None await self.async_set_unique_id(self._udn, raise_on_progress=raise_on_progress) - self._async_update_existing_host_entry(self._host) + if (entry := self._async_update_existing_host_entry()) and _entry_is_complete( + entry + ): + raise data_entry_flow.AbortFlow("already_configured") + + def _async_update_and_abort_for_matching_unique_id(self): + """Abort and update host and mac if we have it.""" updates = {CONF_HOST: self._host} if self._mac: updates[CONF_MAC] = self._mac @@ -178,37 +190,50 @@ class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_show_form(step_id="user", data_schema=DATA_SCHEMA) @callback - def _async_update_existing_host_entry(self, host): + def _async_update_existing_host_entry(self): + """Check existing entries and update them. + + Returns the existing entry if it was updated. + """ for entry in self._async_current_entries(include_ignore=False): - if entry.data[CONF_HOST] != host: + if entry.data[CONF_HOST] != self._host: continue entry_kw_args = {} if self.unique_id and entry.unique_id is None: entry_kw_args["unique_id"] = self.unique_id if self._mac and not entry.data.get(CONF_MAC): - data_copy = dict(entry.data) - data_copy[CONF_MAC] = self._mac - entry_kw_args["data"] = data_copy + entry_kw_args["data"] = {**entry.data, CONF_MAC: self._mac} if entry_kw_args: self.hass.config_entries.async_update_entry(entry, **entry_kw_args) - return entry + self.hass.async_create_task( + self.hass.config_entries.async_reload(entry.entry_id) + ) + return entry return None - async def _async_start_discovery(self): + async def _async_start_discovery_with_mac_address(self): """Start discovery.""" assert self._host is not None - if entry := self._async_update_existing_host_entry(self._host): - if entry.unique_id: - # Let the flow continue to fill the missing - # unique id as we may be able to obtain it - # in the next step - raise data_entry_flow.AbortFlow("already_configured") + if (entry := self._async_update_existing_host_entry()) and entry.unique_id: + # If we have the unique id and the mac we abort + # as we do not need anything else + raise data_entry_flow.AbortFlow("already_configured") + self._async_abort_if_host_already_in_progress() + @callback + def _async_abort_if_host_already_in_progress(self): self.context[CONF_HOST] = self._host for progress in self._async_in_progress(): if progress.get("context", {}).get(CONF_HOST) == self._host: raise data_entry_flow.AbortFlow("already_in_progress") + @callback + def _abort_if_manufacturer_is_not_samsung(self): + if not self._manufacturer or not self._manufacturer.lower().startswith( + "samsung" + ): + raise data_entry_flow.AbortFlow(RESULT_NOT_SUPPORTED) + async def async_step_ssdp(self, discovery_info: DiscoveryInfoType): """Handle a flow initialized by ssdp discovery.""" LOGGER.debug("Samsung device found via SSDP: %s", discovery_info) @@ -216,16 +241,14 @@ class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): self._udn = _strip_uuid(discovery_info[ATTR_UPNP_UDN]) self._host = urlparse(discovery_info[ATTR_SSDP_LOCATION]).hostname await self._async_set_unique_id_from_udn() - await self._async_start_discovery() self._manufacturer = discovery_info[ATTR_UPNP_MANUFACTURER] - if not self._manufacturer or not self._manufacturer.lower().startswith( - "samsung" - ): - raise data_entry_flow.AbortFlow(RESULT_NOT_SUPPORTED) + self._abort_if_manufacturer_is_not_samsung() if not await self._async_get_and_check_device_info(): # If we cannot get device info for an SSDP discovery # its likely a legacy tv. self._name = self._title = self._model = model_name + self._async_update_and_abort_for_matching_unique_id() + self._async_abort_if_host_already_in_progress() self.context["title_placeholders"] = {"device": self._title} return await self.async_step_confirm() @@ -234,7 +257,7 @@ class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): LOGGER.debug("Samsung device found via DHCP: %s", discovery_info) self._mac = discovery_info[MAC_ADDRESS] self._host = discovery_info[IP_ADDRESS] - await self._async_start_discovery() + await self._async_start_discovery_with_mac_address() await self._async_set_device_unique_id() self.context["title_placeholders"] = {"device": self._title} return await self.async_step_confirm() @@ -244,7 +267,7 @@ class SamsungTVConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): LOGGER.debug("Samsung device found via ZEROCONF: %s", discovery_info) self._mac = format_mac(discovery_info[ATTR_PROPERTIES]["deviceid"]) self._host = discovery_info[CONF_HOST] - await self._async_start_discovery() + await self._async_start_discovery_with_mac_address() await self._async_set_device_unique_id() self.context["title_placeholders"] = {"device": self._title} return await self.async_step_confirm() diff --git a/tests/components/samsungtv/test_config_flow.py b/tests/components/samsungtv/test_config_flow.py index 3830673b4cc..a0d2875ca59 100644 --- a/tests/components/samsungtv/test_config_flow.py +++ b/tests/components/samsungtv/test_config_flow.py @@ -894,12 +894,22 @@ async def test_update_missing_mac_unique_id_added_from_dhcp(hass, remotews: Mock """Test missing mac and unique id added.""" entry = MockConfigEntry(domain=DOMAIN, data=MOCK_OLD_ENTRY, unique_id=None) entry.add_to_hass(hass) - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_DHCP}, - data=MOCK_DHCP_DATA, - ) - await hass.async_block_till_done() + with patch( + "homeassistant.components.samsungtv.async_setup", + return_value=True, + ) as mock_setup, patch( + "homeassistant.components.samsungtv.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_DHCP}, + data=MOCK_DHCP_DATA, + ) + await hass.async_block_till_done() + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 + assert result["type"] == "abort" assert result["reason"] == "already_configured" assert entry.data[CONF_MAC] == "aa:bb:cc:dd:ee:ff" @@ -910,18 +920,53 @@ async def test_update_missing_mac_unique_id_added_from_zeroconf(hass, remotews: """Test missing mac and unique id added.""" entry = MockConfigEntry(domain=DOMAIN, data=MOCK_OLD_ENTRY, unique_id=None) entry.add_to_hass(hass) - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_ZEROCONF}, - data=MOCK_ZEROCONF_DATA, - ) - await hass.async_block_till_done() + with patch( + "homeassistant.components.samsungtv.async_setup", + return_value=True, + ) as mock_setup, patch( + "homeassistant.components.samsungtv.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=MOCK_ZEROCONF_DATA, + ) + await hass.async_block_till_done() + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 assert result["type"] == "abort" assert result["reason"] == "already_configured" assert entry.data[CONF_MAC] == "aa:bb:cc:dd:ee:ff" assert entry.unique_id == "be9554b9-c9fb-41f4-8920-22da015376a4" +async def test_update_missing_mac_unique_id_added_from_ssdp(hass, remotews: Mock): + """Test missing mac and unique id added via ssdp.""" + entry = MockConfigEntry(domain=DOMAIN, data=MOCK_OLD_ENTRY, unique_id=None) + entry.add_to_hass(hass) + with patch( + "homeassistant.components.samsungtv.async_setup", + return_value=True, + ) as mock_setup, patch( + "homeassistant.components.samsungtv.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_SSDP}, + data=MOCK_SSDP_DATA, + ) + await hass.async_block_till_done() + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 + + assert result["type"] == "abort" + assert result["reason"] == "already_configured" + assert entry.data[CONF_MAC] == "aa:bb:cc:dd:ee:ff" + assert entry.unique_id == "0d1cef00-00dc-1000-9c80-4844f7b172de" + + async def test_update_missing_mac_added_unique_id_preserved_from_zeroconf( hass, remotews: Mock ): @@ -932,12 +977,21 @@ async def test_update_missing_mac_added_unique_id_preserved_from_zeroconf( unique_id="0d1cef00-00dc-1000-9c80-4844f7b172de", ) entry.add_to_hass(hass) - result = await hass.config_entries.flow.async_init( - DOMAIN, - context={"source": config_entries.SOURCE_ZEROCONF}, - data=MOCK_ZEROCONF_DATA, - ) - await hass.async_block_till_done() + with patch( + "homeassistant.components.samsungtv.async_setup", + return_value=True, + ) as mock_setup, patch( + "homeassistant.components.samsungtv.async_setup_entry", + return_value=True, + ) as mock_setup_entry: + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_ZEROCONF}, + data=MOCK_ZEROCONF_DATA, + ) + await hass.async_block_till_done() + assert len(mock_setup.mock_calls) == 1 + assert len(mock_setup_entry.mock_calls) == 1 assert result["type"] == "abort" assert result["reason"] == "already_configured" assert entry.data[CONF_MAC] == "aa:bb:cc:dd:ee:ff"