Add reauth flow to Youtube (#93670)

This commit is contained in:
Joost Lekkerkerker 2023-05-28 02:29:18 +02:00 committed by GitHub
parent cc12698f26
commit f3037d0b84
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 166 additions and 9 deletions

View File

@ -6,7 +6,7 @@ from aiohttp.client_exceptions import ClientError, ClientResponseError
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.config_entry_oauth2_flow import ( from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session, OAuth2Session,
@ -29,7 +29,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await auth.check_and_refresh_token() await auth.check_and_refresh_token()
except ClientResponseError as err: except ClientResponseError as err:
if 400 <= err.status < 500: if 400 <= err.status < 500:
raise ConfigEntryNotReady( raise ConfigEntryAuthFailed(
"OAuth session is not valid, reauth required" "OAuth session is not valid, reauth required"
) from err ) from err
raise ConfigEntryNotReady from err raise ConfigEntryNotReady from err

View File

@ -1,5 +1,6 @@
"""API for YouTube bound to Home Assistant OAuth.""" """API for YouTube bound to Home Assistant OAuth."""
from aiohttp import ClientSession from aiohttp import ClientSession
from google.auth.exceptions import RefreshError
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
from google.oauth2.utils import OAuthClientAuthHandler from google.oauth2.utils import OAuthClientAuthHandler
from googleapiclient.discovery import Resource, build from googleapiclient.discovery import Resource, build
@ -35,7 +36,11 @@ class AsyncConfigEntryAuth(OAuthClientAuthHandler):
async def get_resource(self) -> Resource: async def get_resource(self) -> Resource:
"""Create executor job to get current resource.""" """Create executor job to get current resource."""
credentials = Credentials(await self.check_and_refresh_token()) try:
credentials = Credentials(await self.check_and_refresh_token())
except RefreshError as ex:
self.oauth_session.config_entry.async_start_reauth(self.oauth_session.hass)
raise ex
return await self.hass.async_add_executor_job(self._get_resource, credentials) return await self.hass.async_add_executor_job(self._get_resource, credentials)
def _get_resource(self, credentials: Credentials) -> Resource: def _get_resource(self, credentials: Credentials) -> Resource:

View File

@ -1,6 +1,7 @@
"""Config flow for YouTube integration.""" """Config flow for YouTube integration."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
@ -10,6 +11,7 @@ from googleapiclient.errors import HttpError
from googleapiclient.http import HttpRequest from googleapiclient.http import HttpRequest
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN from homeassistant.const import CONF_ACCESS_TOKEN, CONF_TOKEN
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
@ -32,6 +34,8 @@ class OAuth2FlowHandler(
DOMAIN = DOMAIN DOMAIN = DOMAIN
reauth_entry: ConfigEntry | None = None
@property @property
def logger(self) -> logging.Logger: def logger(self) -> logging.Logger:
"""Return logger.""" """Return logger."""
@ -47,6 +51,21 @@ class OAuth2FlowHandler(
"prompt": "consent", "prompt": "consent",
} }
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an API authentication error."""
self.reauth_entry = self.hass.config_entries.async_get_entry(
self.context["entry_id"]
)
return await self.async_step_reauth_confirm()
async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Confirm reauth dialog."""
if user_input is None:
return self.async_show_form(step_id="reauth_confirm")
return await self.async_step_user()
async def async_oauth_create_entry(self, data: dict[str, Any]) -> FlowResult: async def async_oauth_create_entry(self, data: dict[str, Any]) -> FlowResult:
"""Create an entry for the flow, or update existing entry.""" """Create an entry for the flow, or update existing entry."""
try: try:
@ -71,10 +90,21 @@ class OAuth2FlowHandler(
self._title = own_channel["snippet"]["title"] self._title = own_channel["snippet"]["title"]
self._data = data self._data = data
await self.async_set_unique_id(own_channel["id"]) if not self.reauth_entry:
self._abort_if_unique_id_configured() await self.async_set_unique_id(own_channel["id"])
self._abort_if_unique_id_configured()
return await self.async_step_channels() return await self.async_step_channels()
if self.reauth_entry.unique_id == own_channel["id"]:
self.hass.config_entries.async_update_entry(self.reauth_entry, data=data)
await self.hass.config_entries.async_reload(self.reauth_entry.entry_id)
return self.async_abort(reason="reauth_successful")
return self.async_abort(
reason="wrong_account",
description_placeholders={"title": self._title},
)
async def async_step_channels( async def async_step_channels(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None

View File

@ -3,6 +3,7 @@ from unittest.mock import patch
from googleapiclient.errors import HttpError from googleapiclient.errors import HttpError
from httplib2 import Response from httplib2 import Response
import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.youtube.const import CONF_CHANNELS, DOMAIN from homeassistant.components.youtube.const import CONF_CHANNELS, DOMAIN
@ -11,8 +12,10 @@ from homeassistant.data_entry_flow import FlowResultType
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from . import MockService from . import MockService
from .conftest import CLIENT_ID, GOOGLE_AUTH_URI, SCOPES, TITLE from .conftest import CLIENT_ID, GOOGLE_AUTH_URI, GOOGLE_TOKEN_URI, SCOPES, TITLE
from tests.common import MockConfigEntry, load_fixture
from tests.test_util.aiohttp import AiohttpClientMocker
from tests.typing import ClientSessionGenerator from tests.typing import ClientSessionGenerator
@ -134,6 +137,101 @@ async def test_flow_http_error(
) )
@pytest.mark.parametrize(
("fixture", "abort_reason", "placeholders", "calls", "access_token"),
[
("get_channel", "reauth_successful", None, 1, "updated-access-token"),
(
"get_channel_2",
"wrong_account",
{"title": "Linus Tech Tips"},
0,
"mock-access-token",
),
],
)
async def test_reauth(
hass: HomeAssistant,
hass_client_no_auth,
aioclient_mock: AiohttpClientMocker,
current_request_with_host,
config_entry: MockConfigEntry,
fixture: str,
abort_reason: str,
placeholders: dict[str, str],
calls: int,
access_token: str,
) -> None:
"""Test the re-authentication case updates the correct config entry.
Make sure we abort if the user selects the
wrong account on the consent screen.
"""
config_entry.add_to_hass(hass)
config_entry.async_start_reauth(hass)
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
result = flows[0]
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"], {})
state = config_entry_oauth2_flow._encode_jwt(
hass,
{
"flow_id": result["flow_id"],
"redirect_uri": "https://example.com/auth/external/callback",
},
)
assert result["url"] == (
f"{GOOGLE_AUTH_URI}?response_type=code&client_id={CLIENT_ID}"
"&redirect_uri=https://example.com/auth/external/callback"
f"&state={state}&scope={'+'.join(SCOPES)}"
"&access_type=offline&prompt=consent"
)
client = await hass_client_no_auth()
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
assert resp.status == 200
assert resp.headers["content-type"] == "text/html; charset=utf-8"
aioclient_mock.clear_requests()
aioclient_mock.post(
GOOGLE_TOKEN_URI,
json={
"refresh_token": "mock-refresh-token",
"access_token": "updated-access-token",
"type": "Bearer",
"expires_in": 60,
},
)
with patch(
"homeassistant.components.youtube.async_setup_entry", return_value=True
) as mock_setup, patch(
"httplib2.Http.request",
return_value=(
Response({}),
bytes(load_fixture(f"youtube/{fixture}.json"), encoding="UTF-8"),
),
):
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert len(hass.config_entries.async_entries(DOMAIN)) == 1
assert result["type"] == "abort"
assert result["reason"] == abort_reason
assert result["description_placeholders"] == placeholders
assert len(mock_setup.mock_calls) == calls
assert config_entry.unique_id == "UC_x5XG1OV2P6uZZ5FSM9Ttw"
assert "token" in config_entry.data
# Verify access token is refreshed
assert config_entry.data["token"]["access_token"] == access_token
assert config_entry.data["token"]["refresh_token"] == "mock-refresh-token"
async def test_flow_exception( async def test_flow_exception(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator, hass_client_no_auth: ClientSessionGenerator,

View File

@ -67,7 +67,7 @@ async def test_expired_token_refresh_success(
( (
time.time() - 3600, time.time() - 3600,
http.HTTPStatus.UNAUTHORIZED, http.HTTPStatus.UNAUTHORIZED,
ConfigEntryState.SETUP_RETRY, ConfigEntryState.SETUP_ERROR,
), ),
( (
time.time() - 3600, time.time() - 3600,

View File

@ -1,9 +1,13 @@
"""Sensor tests for the YouTube integration.""" """Sensor tests for the YouTube integration."""
from unittest.mock import patch
from google.auth.exceptions import RefreshError
from homeassistant import config_entries
from homeassistant.components.youtube import COORDINATOR, DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from .conftest import ComponentSetup from .conftest import TOKEN, ComponentSetup
async def test_sensor(hass: HomeAssistant, setup_integration: ComponentSetup) -> None: async def test_sensor(hass: HomeAssistant, setup_integration: ComponentSetup) -> None:
@ -27,3 +31,23 @@ async def test_sensor(hass: HomeAssistant, setup_integration: ComponentSetup) ->
state.attributes["entity_picture"] state.attributes["entity_picture"]
== "https://yt3.ggpht.com/fca_HuJ99xUxflWdex0XViC3NfctBFreIl8y4i9z411asnGTWY-Ql3MeH_ybA4kNaOjY7kyA=s800-c-k-c0x00ffffff-no-rj" == "https://yt3.ggpht.com/fca_HuJ99xUxflWdex0XViC3NfctBFreIl8y4i9z411asnGTWY-Ql3MeH_ybA4kNaOjY7kyA=s800-c-k-c0x00ffffff-no-rj"
) )
async def test_sensor_reauth_trigger(
hass: HomeAssistant, setup_integration: ComponentSetup
) -> None:
"""Test reauth is triggered after a refresh error."""
await setup_integration()
with patch(TOKEN, side_effect=RefreshError):
entry = hass.config_entries.async_entries(DOMAIN)[0]
await hass.data[DOMAIN][entry.entry_id][COORDINATOR].async_refresh()
await hass.async_block_till_done()
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
flow = flows[0]
assert flow["step_id"] == "reauth_confirm"
assert flow["handler"] == DOMAIN
assert flow["context"]["source"] == config_entries.SOURCE_REAUTH