Fix Withings re-authentication flow (#74961)

This commit is contained in:
epenet 2022-07-11 14:27:54 +02:00 committed by GitHub
parent ab9950621b
commit ce353460b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 38 additions and 63 deletions

View File

@ -32,7 +32,7 @@ from homeassistant.components.application_credentials import AuthImplementation
from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAIN
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN from homeassistant.components.sensor import DOMAIN as SENSOR_DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_WEBHOOK_ID, CONF_WEBHOOK_ID,
MASS_KILOGRAMS, MASS_KILOGRAMS,
@ -57,6 +57,7 @@ from . import const
from .const import Measurement from .const import Measurement
_LOGGER = logging.getLogger(const.LOG_NAMESPACE) _LOGGER = logging.getLogger(const.LOG_NAMESPACE)
_RETRY_COEFFICIENT = 0.5
NOT_AUTHENTICATED_ERROR = re.compile( NOT_AUTHENTICATED_ERROR = re.compile(
f"^{HTTPStatus.UNAUTHORIZED},.*", f"^{HTTPStatus.UNAUTHORIZED},.*",
re.IGNORECASE, re.IGNORECASE,
@ -484,7 +485,7 @@ class ConfigEntryWithingsApi(AbstractWithingsApi):
) -> None: ) -> None:
"""Initialize object.""" """Initialize object."""
self._hass = hass self._hass = hass
self._config_entry = config_entry self.config_entry = config_entry
self._implementation = implementation self._implementation = implementation
self.session = OAuth2Session(hass, config_entry, implementation) self.session = OAuth2Session(hass, config_entry, implementation)
@ -496,7 +497,7 @@ class ConfigEntryWithingsApi(AbstractWithingsApi):
self.session.async_ensure_token_valid(), self._hass.loop self.session.async_ensure_token_valid(), self._hass.loop
).result() ).result()
access_token = self._config_entry.data["token"]["access_token"] access_token = self.config_entry.data["token"]["access_token"]
response = requests.request( response = requests.request(
method, method,
f"{self.URL}/{path}", f"{self.URL}/{path}",
@ -651,7 +652,7 @@ class DataManager:
"Failed attempt %s of %s (%s)", attempt, attempts, exception1 "Failed attempt %s of %s (%s)", attempt, attempts, exception1
) )
# Make each backoff pause a little bit longer # Make each backoff pause a little bit longer
await asyncio.sleep(0.5 * attempt) await asyncio.sleep(_RETRY_COEFFICIENT * attempt)
exception = exception1 exception = exception1
continue continue
@ -738,32 +739,8 @@ class DataManager:
if isinstance( if isinstance(
exception, (UnauthorizedException, AuthFailedException) exception, (UnauthorizedException, AuthFailedException)
) or NOT_AUTHENTICATED_ERROR.match(str(exception)): ) or NOT_AUTHENTICATED_ERROR.match(str(exception)):
context = { self._api.config_entry.async_start_reauth(self._hass)
const.PROFILE: self._profile, return None
"userid": self._user_id,
"source": SOURCE_REAUTH,
}
# Check if reauth flow already exists.
flow = next(
iter(
flow
for flow in self._hass.config_entries.flow.async_progress_by_handler(
const.DOMAIN
)
if flow.context == context
),
None,
)
if flow:
return
# Start a reauth flow.
await self._hass.config_entries.flow.async_init(
const.DOMAIN,
context=context,
)
return
raise exception raise exception

View File

@ -8,7 +8,6 @@ from typing import Any
import voluptuous as vol import voluptuous as vol
from withings_api.common import AuthScope from withings_api.common import AuthScope
from homeassistant.config_entries import SOURCE_REAUTH
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.util import slugify from homeassistant.util import slugify
@ -25,6 +24,7 @@ class WithingsFlowHandler(
# Temporarily holds authorization data during the profile step. # Temporarily holds authorization data during the profile step.
_current_data: dict[str, None | str | int] = {} _current_data: dict[str, None | str | int] = {}
_reauth_profile: str | None = None
@property @property
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
@ -53,12 +53,7 @@ class WithingsFlowHandler(
async def async_step_profile(self, data: dict[str, Any]) -> FlowResult: async def async_step_profile(self, data: dict[str, Any]) -> FlowResult:
"""Prompt the user to select a user profile.""" """Prompt the user to select a user profile."""
errors = {} errors = {}
reauth_profile = ( profile = data.get(const.PROFILE) or self._reauth_profile
self.context.get(const.PROFILE)
if self.context.get("source") == SOURCE_REAUTH
else None
)
profile = data.get(const.PROFILE) or reauth_profile
if profile: if profile:
existing_entries = [ existing_entries = [
@ -67,7 +62,7 @@ class WithingsFlowHandler(
if slugify(config_entry.data.get(const.PROFILE)) == slugify(profile) if slugify(config_entry.data.get(const.PROFILE)) == slugify(profile)
] ]
if reauth_profile or not existing_entries: if self._reauth_profile or not existing_entries:
new_data = {**self._current_data, **data, const.PROFILE: profile} new_data = {**self._current_data, **data, const.PROFILE: profile}
self._current_data = {} self._current_data = {}
return await self.async_step_finish(new_data) return await self.async_step_finish(new_data)
@ -81,16 +76,23 @@ class WithingsFlowHandler(
) )
async def async_step_reauth(self, data: Mapping[str, Any]) -> FlowResult: async def async_step_reauth(self, data: Mapping[str, Any]) -> FlowResult:
"""Prompt user to re-authenticate."""
self._reauth_profile = data.get(const.PROFILE)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, data: dict[str, Any] | None = None
) -> FlowResult:
"""Prompt user to re-authenticate.""" """Prompt user to re-authenticate."""
if data is not None: if data is not None:
return await self.async_step_user() return await self.async_step_user()
placeholders = {const.PROFILE: self.context["profile"]} placeholders = {const.PROFILE: self._reauth_profile}
self.context.update({"title_placeholders": placeholders}) self.context.update({"title_placeholders": placeholders})
return self.async_show_form( return self.async_show_form(
step_id="reauth", step_id="reauth_confirm",
description_placeholders=placeholders, description_placeholders=placeholders,
) )

View File

@ -10,7 +10,7 @@
"pick_implementation": { "pick_implementation": {
"title": "[%key:common::config_flow::title::oauth2_pick_implementation%]" "title": "[%key:common::config_flow::title::oauth2_pick_implementation%]"
}, },
"reauth": { "reauth_confirm": {
"title": "[%key:common::config_flow::title::reauth%]", "title": "[%key:common::config_flow::title::reauth%]",
"description": "The \"{profile}\" profile needs to be re-authenticated in order to continue receiving Withings data." "description": "The \"{profile}\" profile needs to be re-authenticated in order to continue receiving Withings data."
} }

View File

@ -24,7 +24,7 @@
"description": "Provide a unique profile name for this data. Typically this is the name of the profile you selected in the previous step.", "description": "Provide a unique profile name for this data. Typically this is the name of the profile you selected in the previous step.",
"title": "User Profile." "title": "User Profile."
}, },
"reauth": { "reauth_confirm": {
"description": "The \"{profile}\" profile needs to be re-authenticated in order to continue receiving Withings data.", "description": "The \"{profile}\" profile needs to be re-authenticated in order to continue receiving Withings data.",
"title": "Reauthenticate Integration" "title": "Reauthenticate Integration"
} }

View File

@ -42,6 +42,7 @@ from homeassistant.helpers.config_entry_oauth2_flow import AUTH_CALLBACK_PATH
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry
from tests.test_util.aiohttp import AiohttpClientMocker from tests.test_util.aiohttp import AiohttpClientMocker
@ -167,6 +168,10 @@ class ComponentFactory:
) )
api_mock: ConfigEntryWithingsApi = MagicMock(spec=ConfigEntryWithingsApi) api_mock: ConfigEntryWithingsApi = MagicMock(spec=ConfigEntryWithingsApi)
api_mock.config_entry = MockConfigEntry(
domain=const.DOMAIN,
data={"profile": profile_config.profile},
)
ComponentFactory._setup_api_method( ComponentFactory._setup_api_method(
api_mock.user_get_device, profile_config.api_response_user_get_device api_mock.user_get_device, profile_config.api_response_user_get_device
) )
@ -301,15 +306,6 @@ def get_config_entries_for_user_id(
) )
def async_get_flow_for_user_id(hass: HomeAssistant, user_id: int) -> list[dict]:
"""Get a flow for a user id."""
return [
flow
for flow in hass.config_entries.flow.async_progress()
if flow["handler"] == const.DOMAIN and flow["context"].get("userid") == user_id
]
def get_data_manager_by_user_id( def get_data_manager_by_user_id(
hass: HomeAssistant, user_id: int hass: HomeAssistant, user_id: int
) -> DataManager | None: ) -> DataManager | None:

View File

@ -62,11 +62,17 @@ async def test_config_reauth_profile(
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
const.DOMAIN, const.DOMAIN,
context={"source": config_entries.SOURCE_REAUTH, "profile": "person0"}, context={
"source": config_entries.SOURCE_REAUTH,
"entry_id": config_entry.entry_id,
"title_placeholders": {"name": config_entry.title},
"unique_id": config_entry.unique_id,
},
data={"profile": "person0"},
) )
assert result assert result
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "reauth" assert result["step_id"] == "reauth_confirm"
assert result["description_placeholders"] == {const.PROFILE: "person0"} assert result["description_placeholders"] == {const.PROFILE: "person0"}
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(

View File

@ -20,12 +20,7 @@ from homeassistant.core import DOMAIN as HA_DOMAIN, HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from .common import ( from .common import ComponentFactory, get_data_manager_by_user_id, new_profile_config
ComponentFactory,
async_get_flow_for_user_id,
get_data_manager_by_user_id,
new_profile_config,
)
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
@ -122,6 +117,7 @@ async def test_async_setup_no_config(hass: HomeAssistant) -> None:
[Exception("401, this is the message")], [Exception("401, this is the message")],
], ],
) )
@patch("homeassistant.components.withings.common._RETRY_COEFFICIENT", 0)
async def test_auth_failure( async def test_auth_failure(
hass: HomeAssistant, hass: HomeAssistant,
component_factory: ComponentFactory, component_factory: ComponentFactory,
@ -138,20 +134,18 @@ async def test_auth_failure(
) )
await component_factory.configure_component(profile_configs=(person0,)) await component_factory.configure_component(profile_configs=(person0,))
assert not async_get_flow_for_user_id(hass, person0.user_id) assert not hass.config_entries.flow.async_progress()
await component_factory.setup_profile(person0.user_id) await component_factory.setup_profile(person0.user_id)
data_manager = get_data_manager_by_user_id(hass, person0.user_id) data_manager = get_data_manager_by_user_id(hass, person0.user_id)
await data_manager.poll_data_update_coordinator.async_refresh() await data_manager.poll_data_update_coordinator.async_refresh()
flows = async_get_flow_for_user_id(hass, person0.user_id) flows = hass.config_entries.flow.async_progress()
assert flows assert flows
assert len(flows) == 1 assert len(flows) == 1
flow = flows[0] flow = flows[0]
assert flow["handler"] == const.DOMAIN assert flow["handler"] == const.DOMAIN
assert flow["context"]["profile"] == person0.profile
assert flow["context"]["userid"] == person0.user_id
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
flow["flow_id"], user_input={} flow["flow_id"], user_input={}