Simplify nest reauth config flow (#63058)

This commit is contained in:
Allen Porter 2022-01-02 12:43:50 -08:00 committed by GitHub
parent 76a7149a5e
commit 9e3f7d2961
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 15 additions and 36 deletions

View File

@ -121,9 +121,7 @@ class NestFlowHandler(
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize NestFlowHandler.""" """Initialize NestFlowHandler."""
super().__init__() super().__init__()
# Allows updating an existing config entry self._reauth = False
self._reauth_data: dict[str, Any] | None = None
# ConfigEntry data for SDM API
self._data: dict[str, Any] = {DATA_SDM: {}} self._data: dict[str, Any] = {DATA_SDM: {}}
@classmethod @classmethod
@ -169,7 +167,7 @@ class NestFlowHandler(
if user_input is None: if user_input is None:
_LOGGER.error("Reauth invoked with empty config entry data") _LOGGER.error("Reauth invoked with empty config entry data")
return self.async_abort(reason="missing_configuration") return self.async_abort(reason="missing_configuration")
self._reauth_data = user_input self._reauth = True
self._data.update(user_input) self._data.update(user_input)
return await self.async_step_reauth_confirm() return await self.async_step_reauth_confirm()
@ -199,7 +197,7 @@ class NestFlowHandler(
"""Handle a flow initialized by the user.""" """Handle a flow initialized by the user."""
if self.is_sdm_api(): if self.is_sdm_api():
# Reauth will update an existing entry # Reauth will update an existing entry
if self._async_current_entries() and not self._reauth_data: if self._async_current_entries() and not self._reauth:
return self.async_abort(reason="single_instance_allowed") return self.async_abort(reason="single_instance_allowed")
return await super().async_step_user(user_input) return await super().async_step_user(user_input)
return await self.async_step_init(user_input) return await self.async_step_init(user_input)
@ -233,9 +231,9 @@ class NestFlowHandler(
def _configure_pubsub(self) -> bool: def _configure_pubsub(self) -> bool:
"""Return True if the config flow should configure Pub/Sub.""" """Return True if the config flow should configure Pub/Sub."""
if self._reauth_data is not None and CONF_SUBSCRIBER_ID in self._reauth_data: if self._reauth:
# Existing entry needs to be reconfigured # Just refreshing tokens and preserving existing subscriber id
return True return False
if CONF_SUBSCRIBER_ID in self.hass.data[DOMAIN][DATA_NEST_CONFIG]: if CONF_SUBSCRIBER_ID in self.hass.data[DOMAIN][DATA_NEST_CONFIG]:
# Hard coded configuration.yaml skips pubsub in config flow # Hard coded configuration.yaml skips pubsub in config flow
return False return False
@ -249,8 +247,8 @@ class NestFlowHandler(
# Populate data from the previous config entry during reauth, then # Populate data from the previous config entry during reauth, then
# overwrite with the user entered values. # overwrite with the user entered values.
data = {} data = {}
if self._reauth_data: if self._reauth:
data.update(self._reauth_data) data.update(self._data)
if user_input: if user_input:
data.update(user_input) data.update(user_input)
cloud_project_id = data.get(CONF_CLOUD_PROJECT_ID, "") cloud_project_id = data.get(CONF_CLOUD_PROJECT_ID, "")

View File

@ -80,9 +80,7 @@ class OAuthFixture:
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "pick_implementation" assert result["step_id"] == "pick_implementation"
return await self.hass.config_entries.flow.async_configure( return await self.async_configure(result, {"implementation": auth_domain})
result["flow_id"], {"implementation": auth_domain}
)
async def async_oauth_web_flow(self, result: dict) -> None: async def async_oauth_web_flow(self, result: dict) -> None:
"""Invoke the oauth flow for Web Auth with fake responses.""" """Invoke the oauth flow for Web Auth with fake responses."""
@ -169,9 +167,7 @@ class OAuthFixture:
with patch( with patch(
"homeassistant.components.nest.async_setup_entry", return_value=True "homeassistant.components.nest.async_setup_entry", return_value=True
) as mock_setup: ) as mock_setup:
await self.hass.config_entries.flow.async_configure( await self.async_configure(result, user_input)
result["flow_id"], user_input
)
assert len(mock_setup.mock_calls) == 1 assert len(mock_setup.mock_calls) == 1
await self.hass.async_block_till_done() await self.hass.async_block_till_done()
return self.get_config_entry() return self.get_config_entry()
@ -542,7 +538,7 @@ async def test_pubsub_subscriber_config_entry_reauth(hass, oauth, subscriber):
hass, hass,
{ {
"auth_implementation": APP_AUTH_DOMAIN, "auth_implementation": APP_AUTH_DOMAIN,
"subscription_id": SUBSCRIBER_ID, "subscriber_id": SUBSCRIBER_ID,
"cloud_project_id": CLOUD_PROJECT_ID, "cloud_project_id": CLOUD_PROJECT_ID,
"token": { "token": {
"access_token": "some-revoked-token", "access_token": "some-revoked-token",
@ -552,22 +548,9 @@ async def test_pubsub_subscriber_config_entry_reauth(hass, oauth, subscriber):
) )
result = await oauth.async_reauth(old_entry.data) result = await oauth.async_reauth(old_entry.data)
await oauth.async_oauth_app_flow(result) await oauth.async_oauth_app_flow(result)
result = await oauth.async_configure(result, {"code": "1234"})
# Configure Pub/Sub # Entering an updated access token refreshs the config entry.
await oauth.async_pubsub_flow(result, cloud_project_id=CLOUD_PROJECT_ID) entry = await oauth.async_finish_setup(result, {"code": "1234"})
# Verify existing tokens are replaced
with patch(
"homeassistant.components.nest.api.GoogleNestSubscriber",
return_value=subscriber,
):
entry = await oauth.async_finish_setup(
result, {"cloud_project_id": "other-cloud-project-id"}
)
await hass.async_block_till_done()
entry = oauth.get_config_entry()
entry.data["token"].pop("expires_at") entry.data["token"].pop("expires_at")
assert entry.unique_id == DOMAIN assert entry.unique_id == DOMAIN
assert entry.data["token"] == { assert entry.data["token"] == {
@ -577,7 +560,5 @@ async def test_pubsub_subscriber_config_entry_reauth(hass, oauth, subscriber):
"expires_in": 60, "expires_in": 60,
} }
assert entry.data["auth_implementation"] == APP_AUTH_DOMAIN assert entry.data["auth_implementation"] == APP_AUTH_DOMAIN
assert ( assert entry.data["subscriber_id"] == SUBSCRIBER_ID
"projects/other-cloud-project-id/subscriptions" in entry.data["subscriber_id"] assert entry.data["cloud_project_id"] == CLOUD_PROJECT_ID
)
assert entry.data["cloud_project_id"] == "other-cloud-project-id"