diff --git a/homeassistant/components/watttime/__init__.py b/homeassistant/components/watttime/__init__.py index 8b3a83aa8d1..a4e1acb4b7e 100644 --- a/homeassistant/components/watttime/__init__.py +++ b/homeassistant/components/watttime/__init__.py @@ -5,7 +5,7 @@ from datetime import timedelta from aiowatttime import Client from aiowatttime.emissions import RealTimeEmissionsResponseType -from aiowatttime.errors import WattTimeError +from aiowatttime.errors import InvalidCredentialsError, WattTimeError from homeassistant.config_entries import ConfigEntry from homeassistant.const import ( @@ -15,6 +15,7 @@ from homeassistant.const import ( CONF_USERNAME, ) from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers import aiohttp_client from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed @@ -36,6 +37,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: client = await Client.async_login( entry.data[CONF_USERNAME], entry.data[CONF_PASSWORD], session=session ) + except InvalidCredentialsError as err: + raise ConfigEntryAuthFailed("Invalid username/password") from err except WattTimeError as err: LOGGER.error("Error while authenticating with WattTime: %s", err) return False @@ -46,6 +49,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: return await client.emissions.async_get_realtime_emissions( entry.data[CONF_LATITUDE], entry.data[CONF_LONGITUDE] ) + except InvalidCredentialsError as err: + raise ConfigEntryAuthFailed("Invalid username/password") from err except WattTimeError as err: raise UpdateFailed( f"Error while requesting data from WattTime: {err}" diff --git a/homeassistant/components/watttime/config_flow.py b/homeassistant/components/watttime/config_flow.py index 6c523f64331..c2db7847b56 100644 --- a/homeassistant/components/watttime/config_flow.py +++ b/homeassistant/components/watttime/config_flow.py @@ -14,8 +14,10 @@ from homeassistant.const import ( CONF_PASSWORD, CONF_USERNAME, ) +from homeassistant.core import callback from homeassistant.data_entry_flow import FlowResult from homeassistant.helpers import aiohttp_client, config_validation as cv +from homeassistant.helpers.typing import ConfigType from .const import ( CONF_BALANCING_AUTHORITY, @@ -44,6 +46,12 @@ STEP_LOCATION_DATA_SCHEMA = vol.Schema( } ) +STEP_REAUTH_CONFIRM_DATA_SCHEMA = vol.Schema( + { + vol.Required(CONF_PASSWORD): str, + } +) + STEP_USER_DATA_SCHEMA = vol.Schema( { vol.Required(CONF_USERNAME): str, @@ -52,6 +60,12 @@ STEP_USER_DATA_SCHEMA = vol.Schema( ) +@callback +def get_unique_id(data: dict[str, Any]) -> str: + """Get a unique ID from a data payload.""" + return f"{data[CONF_LATITUDE]}, {data[CONF_LONGITUDE]}" + + class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow for WattTime.""" @@ -60,8 +74,49 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): def __init__(self) -> None: """Initialize.""" self._client: Client | None = None - self._password: str | None = None - self._username: str | None = None + self._data: dict[str, Any] = {} + + async def _async_validate_credentials( + self, username: str, password: str, error_step_id: str, error_schema: vol.Schema + ): + """Validate input credentials and proceed accordingly.""" + session = aiohttp_client.async_get_clientsession(self.hass) + + try: + self._client = await Client.async_login(username, password, session=session) + except InvalidCredentialsError: + return self.async_show_form( + step_id=error_step_id, + data_schema=error_schema, + errors={"base": "invalid_auth"}, + description_placeholders={CONF_USERNAME: username}, + ) + except Exception as err: # pylint: disable=broad-except + LOGGER.exception("Unexpected exception while logging in: %s", err) + return self.async_show_form( + step_id=error_step_id, + data_schema=error_schema, + errors={"base": "unknown"}, + description_placeholders={CONF_USERNAME: username}, + ) + + if CONF_LATITUDE in self._data: + # If coordinates already exist at this stage, we're in an existing flow and + # should reauth: + entry_unique_id = get_unique_id(self._data) + if existing_entry := await self.async_set_unique_id(entry_unique_id): + self.hass.config_entries.async_update_entry( + existing_entry, data=self._data + ) + self.hass.async_create_task( + self.hass.config_entries.async_reload(existing_entry.entry_id) + ) + return self.async_abort(reason="reauth_successful") + + # ...otherwise, we're in a new flow: + self._data[CONF_USERNAME] = username + self._data[CONF_PASSWORD] = password + return await self.async_step_location() async def async_step_coordinates( self, user_input: dict[str, Any] | None = None @@ -75,7 +130,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): if TYPE_CHECKING: assert self._client - unique_id = f"{user_input[CONF_LATITUDE]}, {user_input[CONF_LONGITUDE]}" + unique_id = get_unique_id(user_input) await self.async_set_unique_id(unique_id) self._abort_if_unique_id_configured() @@ -100,8 +155,8 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): return self.async_create_entry( title=unique_id, data={ - CONF_USERNAME: self._username, - CONF_PASSWORD: self._password, + CONF_USERNAME: self._data[CONF_USERNAME], + CONF_PASSWORD: self._data[CONF_PASSWORD], CONF_LATITUDE: user_input[CONF_LATITUDE], CONF_LONGITUDE: user_input[CONF_LONGITUDE], CONF_BALANCING_AUTHORITY: grid_region["name"], @@ -127,6 +182,31 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ) return await self.async_step_coordinates() + async def async_step_reauth(self, config: ConfigType) -> FlowResult: + """Handle configuration by re-auth.""" + self._data = {**config} + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Handle re-auth completion.""" + if not user_input: + return self.async_show_form( + step_id="reauth_confirm", + data_schema=STEP_REAUTH_CONFIRM_DATA_SCHEMA, + description_placeholders={CONF_USERNAME: self._data[CONF_USERNAME]}, + ) + + self._data[CONF_PASSWORD] = user_input[CONF_PASSWORD] + + return await self._async_validate_credentials( + self._data[CONF_USERNAME], + self._data[CONF_PASSWORD], + "reauth_confirm", + STEP_REAUTH_CONFIRM_DATA_SCHEMA, + ) + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: @@ -136,28 +216,9 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): step_id="user", data_schema=STEP_USER_DATA_SCHEMA ) - session = aiohttp_client.async_get_clientsession(self.hass) - - try: - self._client = await Client.async_login( - user_input[CONF_USERNAME], - user_input[CONF_PASSWORD], - session=session, - ) - except InvalidCredentialsError: - return self.async_show_form( - step_id="user", - data_schema=STEP_USER_DATA_SCHEMA, - errors={CONF_USERNAME: "invalid_auth"}, - ) - except Exception as err: # pylint: disable=broad-except - LOGGER.exception("Unexpected exception while logging in: %s", err) - return self.async_show_form( - step_id="user", - data_schema=STEP_USER_DATA_SCHEMA, - errors={"base": "unknown"}, - ) - - self._username = user_input[CONF_USERNAME] - self._password = user_input[CONF_PASSWORD] - return await self.async_step_location() + return await self._async_validate_credentials( + user_input[CONF_USERNAME], + user_input[CONF_PASSWORD], + "user", + STEP_USER_DATA_SCHEMA, + ) diff --git a/homeassistant/components/watttime/strings.json b/homeassistant/components/watttime/strings.json index 34dc253dcde..594848afce1 100644 --- a/homeassistant/components/watttime/strings.json +++ b/homeassistant/components/watttime/strings.json @@ -14,6 +14,13 @@ "location_type": "[%key:common::config_flow::data::location%]" } }, + "reauth_confirm": { + "title": "[%key:common::config_flow::title::reauth%]", + "description": "Please re-enter the password for {username}:", + "data": { + "password": "[%key:common::config_flow::data::password%]" + } + }, "user": { "description": "Input your username and password:", "data": { @@ -28,7 +35,8 @@ "unknown_coordinates": "No data for latitude/longitude" }, "abort": { - "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" + "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]" } } } diff --git a/homeassistant/components/watttime/translations/en.json b/homeassistant/components/watttime/translations/en.json index 44ae51fae53..e99af749031 100644 --- a/homeassistant/components/watttime/translations/en.json +++ b/homeassistant/components/watttime/translations/en.json @@ -1,7 +1,8 @@ { "config": { "abort": { - "already_configured": "Device is already configured" + "already_configured": "Device is already configured", + "reauth_successful": "Re-authentication was successful" }, "error": { "invalid_auth": "Invalid authentication", @@ -22,6 +23,13 @@ }, "description": "Pick a location to monitor:" }, + "reauth_confirm": { + "data": { + "password": "Password" + }, + "description": "Please re-enter the password for {username}.", + "title": "Reauthenticate Integration" + }, "user": { "data": { "password": "Password", diff --git a/tests/components/watttime/test_config_flow.py b/tests/components/watttime/test_config_flow.py index e50e89dfb26..249bd51c4da 100644 --- a/tests/components/watttime/test_config_flow.py +++ b/tests/components/watttime/test_config_flow.py @@ -42,9 +42,11 @@ def client_fixture(get_grid_region): @pytest.fixture(name="client_login") def client_login_fixture(client): """Define a fixture for patching the aiowatttime coroutine to get a client.""" - with patch("homeassistant.components.watttime.config_flow.Client.async_login") as m: - m.return_value = client - yield m + with patch( + "homeassistant.components.watttime.config_flow.Client.async_login" + ) as mock_client: + mock_client.return_value = client + yield mock_client @pytest.fixture(name="get_grid_region") @@ -162,7 +164,92 @@ async def test_step_coordinates_unknown_error( assert result["errors"] == {"base": "unknown"} -async def test_step_login_coordinates(hass: HomeAssistant, client_login) -> None: +async def test_step_reauth(hass: HomeAssistant, client_login) -> None: + """Test a full reauth flow.""" + MockConfigEntry( + domain=DOMAIN, + unique_id="51.528308, -0.3817765", + data={ + CONF_USERNAME: "user", + CONF_PASSWORD: "password", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + CONF_BALANCING_AUTHORITY: "Authority 1", + CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1", + }, + ).add_to_hass(hass) + + await setup.async_setup_component(hass, "persistent_notification", {}) + with patch( + "homeassistant.components.watttime.async_setup_entry", + return_value=True, + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH}, + data={ + CONF_USERNAME: "user", + CONF_PASSWORD: "password", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + CONF_BALANCING_AUTHORITY: "Authority 1", + CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1", + }, + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={CONF_PASSWORD: "password"}, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_ABORT + assert result["reason"] == "reauth_successful" + assert len(hass.config_entries.async_entries()) == 1 + + +async def test_step_reauth_invalid_credentials(hass: HomeAssistant) -> None: + """Test that invalid credentials during reauth are handled.""" + MockConfigEntry( + domain=DOMAIN, + unique_id="51.528308, -0.3817765", + data={ + CONF_USERNAME: "user", + CONF_PASSWORD: "password", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + CONF_BALANCING_AUTHORITY: "Authority 1", + CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1", + }, + ).add_to_hass(hass) + + await setup.async_setup_component(hass, "persistent_notification", {}) + with patch( + "homeassistant.components.watttime.config_flow.Client.async_login", + AsyncMock(side_effect=InvalidCredentialsError), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": config_entries.SOURCE_REAUTH}, + data={ + CONF_USERNAME: "user", + CONF_PASSWORD: "password", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + CONF_BALANCING_AUTHORITY: "Authority 1", + CONF_BALANCING_AUTHORITY_ABBREV: "AUTH_1", + }, + ) + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input={CONF_PASSWORD: "password"}, + ) + await hass.async_block_till_done() + + assert result["type"] == RESULT_TYPE_FORM + assert result["errors"] == {"base": "invalid_auth"} + + +async def test_step_user_coordinates(hass: HomeAssistant, client_login) -> None: """Test a full login flow (inputting custom coordinates).""" with patch( @@ -241,7 +328,7 @@ async def test_step_user_invalid_credentials(hass: HomeAssistant) -> None: await hass.async_block_till_done() assert result["type"] == RESULT_TYPE_FORM - assert result["errors"] == {"username": "invalid_auth"} + assert result["errors"] == {"base": "invalid_auth"} @pytest.mark.parametrize("get_grid_region", [AsyncMock(side_effect=Exception)])