diff --git a/homeassistant/components/shelly/__init__.py b/homeassistant/components/shelly/__init__.py index 125e63449ef..dcb2f518144 100644 --- a/homeassistant/components/shelly/__init__.py +++ b/homeassistant/components/shelly/__init__.py @@ -4,10 +4,13 @@ from __future__ import annotations import asyncio from collections.abc import Coroutine from datetime import timedelta +from http import HTTPStatus from typing import Any, Final, cast +from aiohttp import ClientResponseError import aioshelly from aioshelly.block_device import BlockDevice +from aioshelly.exceptions import AuthRequired, InvalidAuthError from aioshelly.rpc_device import RpcDevice import async_timeout import voluptuous as vol @@ -22,7 +25,7 @@ from homeassistant.const import ( Platform, ) from homeassistant.core import Event, HomeAssistant, callback -from homeassistant.exceptions import ConfigEntryNotReady +from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.helpers import aiohttp_client, device_registry, update_coordinator import homeassistant.helpers.config_validation as cv from homeassistant.helpers.debounce import Debouncer @@ -191,12 +194,18 @@ async def async_setup_block_entry(hass: HomeAssistant, entry: ConfigEntry) -> bo try: async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): await device.initialize() + await device.update_status() except asyncio.TimeoutError as err: raise ConfigEntryNotReady( str(err) or "Timeout during device setup" ) from err except OSError as err: raise ConfigEntryNotReady(str(err) or "Error during device setup") from err + except AuthRequired as err: + raise ConfigEntryAuthFailed from err + except ClientResponseError as err: + if err.status == HTTPStatus.UNAUTHORIZED: + raise ConfigEntryAuthFailed from err async_block_device_setup(hass, entry, device) elif sleep_period is None or device_entry is None: @@ -253,6 +262,8 @@ async def async_setup_rpc_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool raise ConfigEntryNotReady(str(err) or "Timeout during device setup") from err except OSError as err: raise ConfigEntryNotReady(str(err) or "Error during device setup") from err + except (AuthRequired, InvalidAuthError) as err: + raise ConfigEntryAuthFailed from err device_wrapper = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id][ RPC diff --git a/homeassistant/components/shelly/climate.py b/homeassistant/components/shelly/climate.py index 0bdcb3a9ad9..f98c048d569 100644 --- a/homeassistant/components/shelly/climate.py +++ b/homeassistant/components/shelly/climate.py @@ -6,6 +6,7 @@ from collections.abc import Mapping from typing import Any, cast from aioshelly.block_device import Block +from aioshelly.exceptions import AuthRequired import async_timeout from homeassistant.components.climate import ( @@ -318,11 +319,14 @@ class BlockSleepingClimate( assert self.block.channel - self._preset_modes = [ - PRESET_NONE, - *self.wrapper.device.settings["thermostats"][int(self.block.channel)][ - "schedule_profile_names" - ], - ] - - self.async_write_ha_state() + try: + self._preset_modes = [ + PRESET_NONE, + *self.wrapper.device.settings["thermostats"][ + int(self.block.channel) + ]["schedule_profile_names"], + ] + except AuthRequired: + self.wrapper.entry.async_start_reauth(self.hass) + else: + self.async_write_ha_state() diff --git a/homeassistant/components/shelly/config_flow.py b/homeassistant/components/shelly/config_flow.py index 41e0bd3031a..38d30fd0b62 100644 --- a/homeassistant/components/shelly/config_flow.py +++ b/homeassistant/components/shelly/config_flow.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio +from collections.abc import Mapping from http import HTTPStatus from typing import Any, Final @@ -91,6 +92,7 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): host: str = "" info: dict[str, Any] = {} device_info: dict[str, Any] = {} + entry: config_entries.ConfigEntry | None = None async def async_step_user( self, user_input: dict[str, Any] | None = None @@ -262,6 +264,53 @@ class ConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): errors=errors, ) + async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: + """Handle configuration by re-auth.""" + self.entry = self.hass.config_entries.async_get_entry(self.context["entry_id"]) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Dialog that informs the user that reauth is required.""" + errors: dict[str, str] = {} + assert self.entry is not None + host = self.entry.data[CONF_HOST] + + if user_input is not None: + info = await self._async_get_info(host) + if self.entry.data.get("gen", 1) != 1: + user_input[CONF_USERNAME] = "admin" + try: + await validate_input(self.hass, host, info, user_input) + except ( + aiohttp.ClientResponseError, + aioshelly.exceptions.InvalidAuthError, + asyncio.TimeoutError, + aiohttp.ClientError, + ): + return self.async_abort(reason="reauth_unsuccessful") + else: + self.hass.config_entries.async_update_entry( + self.entry, data={**self.entry.data, **user_input} + ) + await self.hass.config_entries.async_reload(self.entry.entry_id) + return self.async_abort(reason="reauth_successful") + + if self.entry.data.get("gen", 1) == 1: + schema = { + vol.Required(CONF_USERNAME): str, + vol.Required(CONF_PASSWORD): str, + } + else: + schema = {vol.Required(CONF_PASSWORD): str} + + return self.async_show_form( + step_id="reauth_confirm", + data_schema=vol.Schema(schema), + errors=errors, + ) + async def _async_get_info(self, host: str) -> dict[str, Any]: """Get info from shelly device.""" async with async_timeout.timeout(AIOSHELLY_DEVICE_TIMEOUT_SEC): diff --git a/homeassistant/components/shelly/strings.json b/homeassistant/components/shelly/strings.json index db1c6043187..d3684f85be2 100644 --- a/homeassistant/components/shelly/strings.json +++ b/homeassistant/components/shelly/strings.json @@ -14,6 +14,12 @@ "password": "[%key:common::config_flow::data::password%]" } }, + "reauth_confirm": { + "data": { + "username": "[%key:common::config_flow::data::username%]", + "password": "[%key:common::config_flow::data::password%]" + } + }, "confirm_discovery": { "description": "Do you want to set up the {model} at {host}?\n\nBattery-powered devices that are password protected must be woken up before continuing with setting up.\nBattery-powered devices that are not password protected will be added when the device wakes up, you can now manually wake the device up using a button on it or wait for the next data update from the device." } @@ -26,7 +32,9 @@ }, "abort": { "already_configured": "[%key:common::config_flow::abort::already_configured_device%]", - "unsupported_firmware": "The device is using an unsupported firmware version." + "unsupported_firmware": "The device is using an unsupported firmware version.", + "reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]", + "reauth_unsuccessful": "Re-authentication was unsuccessful, please remove the integration and set it up again." } }, "device_automation": { diff --git a/tests/components/shelly/test_config_flow.py b/tests/components/shelly/test_config_flow.py index f47fdef0994..b7083cb6805 100644 --- a/tests/components/shelly/test_config_flow.py +++ b/tests/components/shelly/test_config_flow.py @@ -10,6 +10,7 @@ import pytest from homeassistant import config_entries, data_entry_flow from homeassistant.components import zeroconf from homeassistant.components.shelly.const import DOMAIN +from homeassistant.config_entries import SOURCE_REAUTH from tests.common import MockConfigEntry @@ -780,3 +781,107 @@ async def test_zeroconf_require_auth(hass): } assert len(mock_setup.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1 + + +@pytest.mark.parametrize( + "test_data", + [ + (1, {"username": "test user", "password": "test1 password"}), + (2, {"password": "test2 password"}), + ], +) +async def test_reauth_successful(hass, test_data): + """Test starting a reauthentication flow.""" + gen, user_input = test_data + entry = MockConfigEntry( + domain="shelly", unique_id="test-mac", data={"host": "0.0.0.0", "gen": gen} + ) + entry.add_to_hass(hass) + + with patch( + "aioshelly.common.get_info", + return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True, "gen": gen}, + ), patch( + "aioshelly.block_device.BlockDevice.create", + new=AsyncMock( + return_value=Mock( + model="SHSW-1", + settings=MOCK_SETTINGS, + ) + ), + ), patch( + "aioshelly.rpc_device.RpcDevice.create", + new=AsyncMock( + return_value=Mock( + shelly={"model": "SHSW-1", "gen": gen}, + config=MOCK_CONFIG, + shutdown=AsyncMock(), + ) + ), + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_REAUTH, "entry_id": entry.entry_id}, + data=entry.data, + ) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "reauth_successful" + + +@pytest.mark.parametrize( + "test_data", + [ + ( + 1, + {"username": "test user", "password": "test1 password"}, + aioshelly.exceptions.InvalidAuthError(code=HTTPStatus.UNAUTHORIZED.value), + ), + ( + 2, + {"password": "test2 password"}, + aiohttp.ClientResponseError(Mock(), (), status=HTTPStatus.UNAUTHORIZED), + ), + ], +) +async def test_reauth_unsuccessful(hass, test_data): + """Test reauthentication flow failed.""" + gen, user_input, exc = test_data + entry = MockConfigEntry( + domain="shelly", unique_id="test-mac", data={"host": "0.0.0.0", "gen": gen} + ) + entry.add_to_hass(hass) + + with patch( + "aioshelly.common.get_info", + return_value={"mac": "test-mac", "type": "SHSW-1", "auth": True, "gen": gen}, + ), patch( + "aioshelly.block_device.BlockDevice.create", + new=AsyncMock(side_effect=exc), + ), patch( + "aioshelly.rpc_device.RpcDevice.create", new=AsyncMock(side_effect=exc) + ): + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_REAUTH, "entry_id": entry.entry_id}, + data=entry.data, + ) + + assert result["type"] == data_entry_flow.FlowResultType.FORM + assert result["step_id"] == "reauth_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + user_input=user_input, + ) + + assert result["type"] == data_entry_flow.FlowResultType.ABORT + assert result["reason"] == "reauth_unsuccessful"