Fix Shelly reauth flow (#114180)

* Fix Shelly reauth flow

* Rename shutdown_device to async_shutdown_device
This commit is contained in:
Shay Levy 2024-03-25 23:27:44 +02:00 committed by GitHub
parent 9b682388f5
commit 121182167f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 46 additions and 21 deletions

View File

@ -53,6 +53,7 @@ from .coordinator import (
) )
from .utils import ( from .utils import (
async_create_issue_unsupported_firmware, async_create_issue_unsupported_firmware,
async_shutdown_device,
get_block_device_sleep_period, get_block_device_sleep_period,
get_coap_context, get_coap_context,
get_device_entry_gen, get_device_entry_gen,
@ -339,12 +340,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
shelly_entry_data = get_entry_data(hass)[entry.entry_id] shelly_entry_data = get_entry_data(hass)[entry.entry_id]
# If device is present, block/rpc coordinator is not setup yet # If device is present, block/rpc coordinator is not setup yet
device = shelly_entry_data.device if (device := shelly_entry_data.device) is not None:
if isinstance(device, RpcDevice): await async_shutdown_device(device)
await device.shutdown()
return True
if isinstance(device, BlockDevice):
device.shutdown()
return True return True
platforms = RPC_SLEEPING_PLATFORMS platforms = RPC_SLEEPING_PLATFORMS

View File

@ -319,7 +319,7 @@ class BlockSleepingClimate(
f" {repr(err)}" f" {repr(err)}"
) from err ) from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()
async def async_set_temperature(self, **kwargs: Any) -> None: async def async_set_temperature(self, **kwargs: Any) -> None:
"""Set new target temperature.""" """Set new target temperature."""
@ -436,7 +436,10 @@ class BlockSleepingClimate(
]["schedule_profile_names"], ]["schedule_profile_names"],
] ]
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) self.hass.async_create_task(
self.coordinator.async_shutdown_device_and_start_reauth(),
eager_start=True,
)
else: else:
self.async_write_ha_state() self.async_write_ha_state()

View File

@ -24,7 +24,13 @@ from homeassistant.config_entries import (
ConfigFlowResult, ConfigFlowResult,
OptionsFlow, OptionsFlow,
) )
from homeassistant.const import CONF_HOST, CONF_PASSWORD, CONF_PORT, CONF_USERNAME from homeassistant.const import (
CONF_HOST,
CONF_MAC,
CONF_PASSWORD,
CONF_PORT,
CONF_USERNAME,
)
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig from homeassistant.helpers.selector import SelectSelector, SelectSelectorConfig
@ -84,6 +90,7 @@ async def validate_input(
ip_address=host, ip_address=host,
username=data.get(CONF_USERNAME), username=data.get(CONF_USERNAME),
password=data.get(CONF_PASSWORD), password=data.get(CONF_PASSWORD),
device_mac=info[CONF_MAC],
port=port, port=port,
) )
@ -153,7 +160,7 @@ class ShellyConfigFlow(ConfigFlow, domain=DOMAIN):
LOGGER.exception("Unexpected exception") LOGGER.exception("Unexpected exception")
errors["base"] = "unknown" errors["base"] = "unknown"
else: else:
await self.async_set_unique_id(self.info["mac"]) await self.async_set_unique_id(self.info[CONF_MAC])
self._abort_if_unique_id_configured({CONF_HOST: host}) self._abort_if_unique_id_configured({CONF_HOST: host})
self.host = host self.host = host
self.port = port self.port = port
@ -286,7 +293,7 @@ class ShellyConfigFlow(ConfigFlow, domain=DOMAIN):
if not mac: if not mac:
# We could not get the mac address from the name # We could not get the mac address from the name
# so need to check here since we just got the info # so need to check here since we just got the info
await self._async_discovered_mac(self.info["mac"], host) await self._async_discovered_mac(self.info[CONF_MAC], host)
self.host = host self.host = host
self.context.update( self.context.update(

View File

@ -58,6 +58,7 @@ from .const import (
BLEScannerMode, BLEScannerMode,
) )
from .utils import ( from .utils import (
async_shutdown_device,
get_device_entry_gen, get_device_entry_gen,
get_http_port, get_http_port,
get_rpc_device_wakeup_period, get_rpc_device_wakeup_period,
@ -151,6 +152,14 @@ class ShellyCoordinatorBase(DataUpdateCoordinator[None], Generic[_DeviceT]):
LOGGER.debug("Reloading entry %s", self.name) LOGGER.debug("Reloading entry %s", self.name)
await self.hass.config_entries.async_reload(self.entry.entry_id) await self.hass.config_entries.async_reload(self.entry.entry_id)
async def async_shutdown_device_and_start_reauth(self) -> None:
"""Shutdown Shelly device and start reauth flow."""
# not running disconnect events since we have auth error
# and won't be able to send commands to the device
self.last_update_success = False
await async_shutdown_device(self.device)
self.entry.async_start_reauth(self.hass)
class ShellyBlockCoordinator(ShellyCoordinatorBase[BlockDevice]): class ShellyBlockCoordinator(ShellyCoordinatorBase[BlockDevice]):
"""Coordinator for a Shelly block based device.""" """Coordinator for a Shelly block based device."""
@ -300,7 +309,7 @@ class ShellyBlockCoordinator(ShellyCoordinatorBase[BlockDevice]):
except DeviceConnectionError as err: except DeviceConnectionError as err:
raise UpdateFailed(f"Error fetching data: {repr(err)}") from err raise UpdateFailed(f"Error fetching data: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.entry.async_start_reauth(self.hass) await self.async_shutdown_device_and_start_reauth()
@callback @callback
def _async_handle_update( def _async_handle_update(
@ -384,7 +393,7 @@ class ShellyRestCoordinator(ShellyCoordinatorBase[BlockDevice]):
except DeviceConnectionError as err: except DeviceConnectionError as err:
raise UpdateFailed(f"Error fetching data: {repr(err)}") from err raise UpdateFailed(f"Error fetching data: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.entry.async_start_reauth(self.hass) await self.async_shutdown_device_and_start_reauth()
else: else:
update_device_fw_info(self.hass, self.device, self.entry) update_device_fw_info(self.hass, self.device, self.entry)
@ -540,7 +549,7 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
except DeviceConnectionError as err: except DeviceConnectionError as err:
raise UpdateFailed(f"Device disconnected: {repr(err)}") from err raise UpdateFailed(f"Device disconnected: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.entry.async_start_reauth(self.hass) await self.async_shutdown_device_and_start_reauth()
async def _async_disconnected(self) -> None: async def _async_disconnected(self) -> None:
"""Handle device disconnected.""" """Handle device disconnected."""
@ -633,7 +642,8 @@ class ShellyRpcCoordinator(ShellyCoordinatorBase[RpcDevice]):
try: try:
await async_stop_scanner(self.device) await async_stop_scanner(self.device)
except InvalidAuthError: except InvalidAuthError:
self.entry.async_start_reauth(self.hass) await self.async_shutdown_device_and_start_reauth()
return
await self.device.shutdown() await self.device.shutdown()
await self._async_disconnected() await self._async_disconnected()
@ -663,7 +673,7 @@ class ShellyRpcPollingCoordinator(ShellyCoordinatorBase[RpcDevice]):
except (DeviceConnectionError, RpcCallError) as err: except (DeviceConnectionError, RpcCallError) as err:
raise UpdateFailed(f"Device disconnected: {repr(err)}") from err raise UpdateFailed(f"Device disconnected: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.entry.async_start_reauth(self.hass) await self.async_shutdown_device_and_start_reauth()
def get_block_coordinator_by_device_id( def get_block_coordinator_by_device_id(

View File

@ -344,7 +344,7 @@ class ShellyBlockEntity(CoordinatorEntity[ShellyBlockCoordinator]):
f" {repr(err)}" f" {repr(err)}"
) from err ) from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()
class ShellyRpcEntity(CoordinatorEntity[ShellyRpcCoordinator]): class ShellyRpcEntity(CoordinatorEntity[ShellyRpcCoordinator]):
@ -397,7 +397,7 @@ class ShellyRpcEntity(CoordinatorEntity[ShellyRpcCoordinator]):
f" {params}, error: {repr(err)}" f" {params}, error: {repr(err)}"
) from err ) from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()
class ShellyBlockAttributeEntity(ShellyBlockEntity, Entity): class ShellyBlockAttributeEntity(ShellyBlockEntity, Entity):

View File

@ -126,4 +126,4 @@ class BlockSleepingNumber(ShellySleepingBlockAttributeEntity, RestoreNumber):
f" {repr(err)}" f" {repr(err)}"
) from err ) from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()

View File

@ -200,7 +200,7 @@ class RestUpdateEntity(ShellyRestAttributeEntity, UpdateEntity):
except DeviceConnectionError as err: except DeviceConnectionError as err:
raise HomeAssistantError(f"Error starting OTA update: {repr(err)}") from err raise HomeAssistantError(f"Error starting OTA update: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()
else: else:
LOGGER.debug("Result of OTA update call: %s", result) LOGGER.debug("Result of OTA update call: %s", result)
@ -289,7 +289,7 @@ class RpcUpdateEntity(ShellyRpcAttributeEntity, UpdateEntity):
except RpcCallError as err: except RpcCallError as err:
raise HomeAssistantError(f"OTA update request error: {repr(err)}") from err raise HomeAssistantError(f"OTA update request error: {repr(err)}") from err
except InvalidAuthError: except InvalidAuthError:
self.coordinator.entry.async_start_reauth(self.hass) await self.coordinator.async_shutdown_device_and_start_reauth()
else: else:
self._ota_in_progress = True self._ota_in_progress = True
LOGGER.debug("OTA update call successful") LOGGER.debug("OTA update call successful")

View File

@ -480,3 +480,11 @@ def is_rpc_wifi_stations_disabled(
def get_http_port(data: MappingProxyType[str, Any]) -> int: def get_http_port(data: MappingProxyType[str, Any]) -> int:
"""Get port from config entry data.""" """Get port from config entry data."""
return cast(int, data.get(CONF_PORT, DEFAULT_HTTP_PORT)) return cast(int, data.get(CONF_PORT, DEFAULT_HTTP_PORT))
async def async_shutdown_device(device: BlockDevice | RpcDevice) -> None:
"""Shutdown a Shelly device."""
if isinstance(device, RpcDevice):
await device.shutdown()
if isinstance(device, BlockDevice):
device.shutdown()