diff --git a/homeassistant/components/bmw_connected_drive/config_flow.py b/homeassistant/components/bmw_connected_drive/config_flow.py index 3468ee25ca1..37ff1eb374c 100644 --- a/homeassistant/components/bmw_connected_drive/config_flow.py +++ b/homeassistant/components/bmw_connected_drive/config_flow.py @@ -21,7 +21,6 @@ from homeassistant.config_entries import ( ) from homeassistant.const import CONF_PASSWORD, CONF_REGION, CONF_SOURCE, CONF_USERNAME from homeassistant.core import HomeAssistant, callback -from homeassistant.data_entry_flow import AbortFlow from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig @@ -75,7 +74,6 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN): VERSION = 1 _existing_entry_data: Mapping[str, Any] | None = None - _existing_entry_unique_id: str | None = None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -85,15 +83,12 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN): if user_input is not None: unique_id = f"{user_input[CONF_REGION]}-{user_input[CONF_USERNAME]}" + await self.async_set_unique_id(unique_id) - if self.source not in {SOURCE_REAUTH, SOURCE_RECONFIGURE}: - await self.async_set_unique_id(unique_id) + if self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE}: + self._abort_if_unique_id_mismatch(reason="account_mismatch") + else: self._abort_if_unique_id_configured() - elif ( - self.source in {SOURCE_REAUTH, SOURCE_RECONFIGURE} - and unique_id != self._existing_entry_unique_id - ): - raise AbortFlow("account_mismatch") info = None try: @@ -135,16 +130,13 @@ class BMWConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Handle configuration by re-auth.""" self._existing_entry_data = entry_data - self._existing_entry_unique_id = self._get_reauth_entry().unique_id return await self.async_step_user() async def async_step_reconfigure( self, user_input: dict[str, Any] | None = None ) -> ConfigFlowResult: """Handle a reconfiguration flow initialized by the user.""" - reconfigure_entry = self._get_reconfigure_entry() - self._existing_entry_data = reconfigure_entry.data - self._existing_entry_unique_id = reconfigure_entry.unique_id + self._existing_entry_data = self._get_reconfigure_entry().data return await self.async_step_user() @staticmethod diff --git a/homeassistant/components/spotify/config_flow.py b/homeassistant/components/spotify/config_flow.py index 510f608746e..58342ba368f 100644 --- a/homeassistant/components/spotify/config_flow.py +++ b/homeassistant/components/spotify/config_flow.py @@ -50,11 +50,9 @@ class SpotifyFlowHandler( await self.async_set_unique_id(current_user["id"]) if self.source == SOURCE_REAUTH: - reauth_entry = self._get_reauth_entry() - if reauth_entry.data["id"] != current_user["id"]: - return self.async_abort(reason="reauth_account_mismatch") + self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch") return self.async_update_reload_and_abort( - reauth_entry, title=name, data=data + self._get_reauth_entry(), title=name, data=data ) return self.async_create_entry(title=name, data=data) diff --git a/homeassistant/components/tesla_fleet/config_flow.py b/homeassistant/components/tesla_fleet/config_flow.py index 64b88792387..ca36c6f511b 100644 --- a/homeassistant/components/tesla_fleet/config_flow.py +++ b/homeassistant/components/tesla_fleet/config_flow.py @@ -8,7 +8,7 @@ from typing import Any import jwt -from homeassistant.config_entries import ConfigEntry, ConfigFlowResult +from homeassistant.config_entries import SOURCE_REAUTH, ConfigFlowResult from homeassistant.helpers import config_entry_oauth2_flow from .const import DOMAIN, LOGGER @@ -21,7 +21,6 @@ class OAuth2FlowHandler( """Config flow to handle Tesla Fleet API OAuth2 authentication.""" DOMAIN = DOMAIN - reauth_entry: ConfigEntry | None = None @property def logger(self) -> logging.Logger: @@ -50,32 +49,19 @@ class OAuth2FlowHandler( ) uid = token["sub"] - if not self.reauth_entry: - await self.async_set_unique_id(uid) - self._abort_if_unique_id_configured() - - return self.async_create_entry(title=uid, data=data) - - if self.reauth_entry.unique_id == uid: - self.hass.config_entries.async_update_entry( - self.reauth_entry, - data=data, + await self.async_set_unique_id(uid) + if self.source == SOURCE_REAUTH: + self._abort_if_unique_id_mismatch(reason="reauth_account_mismatch") + return self.async_update_reload_and_abort( + self._get_reauth_entry(), data=data ) - await self.hass.config_entries.async_reload(self.reauth_entry.entry_id) - return self.async_abort(reason="reauth_successful") - - return self.async_abort( - reason="reauth_account_mismatch", - description_placeholders={"title": self.reauth_entry.title}, - ) + self._abort_if_unique_id_configured() + return self.async_create_entry(title=uid, data=data) async def async_step_reauth( self, entry_data: Mapping[str, Any] ) -> ConfigFlowResult: """Perform reauth upon an API authentication error.""" - self.reauth_entry = self.hass.config_entries.async_get_entry( - self.context["entry_id"] - ) return await self.async_step_reauth_confirm() async def async_step_reauth_confirm( diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 28fecf9bcc4..a7b1b3b8d77 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -2432,6 +2432,26 @@ class ConfigFlow(ConfigEntryBaseFlow): self._async_current_entries(include_ignore=False), match_dict ) + @callback + def _abort_if_unique_id_mismatch( + self, + *, + reason: str = "unique_id_mismatch", + ) -> None: + """Abort if the unique ID does not match the reauth/reconfigure context. + + Requires strings.json entry corresponding to the `reason` parameter + in user visible flows. + """ + if ( + self.source == SOURCE_REAUTH + and self._get_reauth_entry().unique_id != self.unique_id + ) or ( + self.source == SOURCE_RECONFIGURE + and self._get_reconfigure_entry().unique_id != self.unique_id + ): + raise data_entry_flow.AbortFlow(reason) + @callback def _abort_if_unique_id_configured( self, diff --git a/tests/test_config_entries.py b/tests/test_config_entries.py index db78fb2903e..997a6231b58 100644 --- a/tests/test_config_entries.py +++ b/tests/test_config_entries.py @@ -6677,6 +6677,73 @@ async def test_reauth_helper_alignment( assert helper_flow_init_data == reauth_flow_init_data +@pytest.mark.parametrize( + ("original_unique_id", "new_unique_id", "reason"), + [ + ("unique", "unique", "success"), + (None, None, "success"), + ("unique", "new", "unique_id_mismatch"), + ("unique", None, "unique_id_mismatch"), + (None, "new", "unique_id_mismatch"), + ], +) +@pytest.mark.parametrize( + "source", + [config_entries.SOURCE_REAUTH, config_entries.SOURCE_RECONFIGURE], +) +async def test_abort_if_unique_id_mismatch( + hass: HomeAssistant, + source: str, + original_unique_id: str | None, + new_unique_id: str | None, + reason: str, +) -> None: + """Test to check if_unique_id_mismatch behavior.""" + entry = MockConfigEntry( + title="From config flow", + domain="test", + entry_id="01J915Q6T9F6G5V0QJX6HBC94T", + data={"host": "any", "port": 123}, + unique_id=original_unique_id, + ) + entry.add_to_hass(hass) + + mock_setup_entry = AsyncMock(return_value=True) + + mock_integration(hass, MockModule("test", async_setup_entry=mock_setup_entry)) + mock_platform(hass, "test.config_flow", None) + + class TestFlow(config_entries.ConfigFlow): + VERSION = 1 + + async def async_step_user(self, user_input=None): + """Test user step.""" + return await self._async_step_confirm() + + async def async_step_reauth(self, entry_data): + """Test reauth step.""" + return await self._async_step_confirm() + + async def async_step_reconfigure(self, user_input=None): + """Test reauth step.""" + return await self._async_step_confirm() + + async def _async_step_confirm(self): + """Confirm input.""" + await self.async_set_unique_id(new_unique_id) + self._abort_if_unique_id_mismatch() + return self.async_abort(reason="success") + + with mock_config_flow("test", TestFlow): + if source == config_entries.SOURCE_REAUTH: + result = await entry.start_reauth_flow(hass) + elif source == config_entries.SOURCE_RECONFIGURE: + result = await entry.start_reconfigure_flow(hass) + await hass.async_block_till_done() + assert result["type"] is FlowResultType.ABORT + assert result["reason"] == reason + + def test_state_not_stored_in_storage() -> None: """Test that state is not stored in storage.