Add reauth flow to webOS TV integration (#86168)

* Add reauth flow to webOS TV integration

* Remove unnecessary else
This commit is contained in:
Shay Levy 2023-01-18 18:48:38 +02:00 committed by GitHub
parent f2b348dbdf
commit c40c37e9ee
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 217 additions and 37 deletions

View File

@ -18,6 +18,7 @@ from homeassistant.const import (
EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_STOP,
) )
from homeassistant.core import Event, HomeAssistant, ServiceCall from homeassistant.core import Event, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import config_validation as cv, discovery from homeassistant.helpers import config_validation as cv, discovery
from homeassistant.helpers.dispatcher import async_dispatcher_send from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
@ -77,8 +78,15 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# Attempt a connection, but fail gracefully if tv is off for example. # Attempt a connection, but fail gracefully if tv is off for example.
client = WebOsClient(host, key) client = WebOsClient(host, key)
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError): with suppress(*WEBOSTV_EXCEPTIONS):
try:
await client.connect() await client.connect()
except WebOsTvPairError as err:
raise ConfigEntryAuthFailed(err) from err
# If pairing request accepted there will be no error
# Update the stored key without triggering reauth
update_client_key(hass, entry, client)
async def async_service_handler(service: ServiceCall) -> None: async def async_service_handler(service: ServiceCall) -> None:
method = SERVICE_TO_METHOD[service.service] method = SERVICE_TO_METHOD[service.service]
@ -141,6 +149,19 @@ async def async_control_connect(host: str, key: str | None) -> WebOsClient:
return client return client
def update_client_key(
hass: HomeAssistant, entry: ConfigEntry, client: WebOsClient
) -> None:
"""Check and update stored client key if key has changed."""
host = entry.data[CONF_HOST]
key = entry.data[CONF_CLIENT_SECRET]
if client.client_key != key:
_LOGGER.debug("Updating client key for host %s", host)
data = {CONF_HOST: host, CONF_CLIENT_SECRET: client.client_key}
hass.config_entries.async_update_entry(entry, data=data)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) unload_ok = await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@ -1,6 +1,7 @@
"""Config flow to configure webostv component.""" """Config flow to configure webostv component."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping
import logging import logging
from typing import Any from typing import Any
from urllib.parse import urlparse from urllib.parse import urlparse
@ -8,14 +9,14 @@ from urllib.parse import urlparse
from aiowebostv import WebOsTvPairError from aiowebostv import WebOsTvPairError
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries, data_entry_flow
from homeassistant.components import ssdp from homeassistant.components import ssdp
from homeassistant.config_entries import ConfigEntry, ConfigFlow, OptionsFlow
from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import AbortFlow, FlowResult
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from . import async_control_connect from . import async_control_connect, update_client_key
from .const import CONF_SOURCES, DEFAULT_NAME, DOMAIN, WEBOSTV_EXCEPTIONS from .const import CONF_SOURCES, DEFAULT_NAME, DOMAIN, WEBOSTV_EXCEPTIONS
from .helpers import async_get_sources from .helpers import async_get_sources
@ -30,7 +31,7 @@ DATA_SCHEMA = vol.Schema(
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN): class FlowHandler(ConfigFlow, domain=DOMAIN):
"""WebosTV configuration flow.""" """WebosTV configuration flow."""
VERSION = 1 VERSION = 1
@ -40,12 +41,11 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._host: str = "" self._host: str = ""
self._name: str = "" self._name: str = ""
self._uuid: str | None = None self._uuid: str | None = None
self._entry: ConfigEntry | None = None
@staticmethod @staticmethod
@callback @callback
def async_get_options_flow( def async_get_options_flow(config_entry: ConfigEntry) -> OptionsFlow:
config_entry: config_entries.ConfigEntry,
) -> OptionsFlowHandler:
"""Get the options flow for this handler.""" """Get the options flow for this handler."""
return OptionsFlowHandler(config_entry) return OptionsFlowHandler(config_entry)
@ -78,7 +78,7 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
) )
self.hass.config_entries.async_update_entry(entry, unique_id=self._uuid) self.hass.config_entries.async_update_entry(entry, unique_id=self._uuid)
raise data_entry_flow.AbortFlow("already_configured") raise AbortFlow("already_configured")
async def async_step_pairing( async def async_step_pairing(
self, user_input: dict[str, Any] | None = None self, user_input: dict[str, Any] | None = None
@ -129,11 +129,37 @@ class FlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
self._uuid = uuid self._uuid = uuid
return await self.async_step_pairing() return await self.async_step_pairing()
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Perform reauth upon an WebOsTvPairError."""
self._host = entry_data[CONF_HOST]
self._entry = self.hass.config_entries.async_get_entry(self.context["entry_id"])
return await self.async_step_reauth_confirm()
class OptionsFlowHandler(config_entries.OptionsFlow): async def async_step_reauth_confirm(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Dialog that informs the user that reauth is required."""
assert self._entry is not None
if user_input is not None:
try:
client = await async_control_connect(self._host, None)
except WebOsTvPairError:
return self.async_abort(reason="error_pairing")
except WEBOSTV_EXCEPTIONS:
return self.async_abort(reason="reauth_unsuccessful")
update_client_key(self.hass, self._entry, client)
await self.hass.config_entries.async_reload(self._entry.entry_id)
return self.async_abort(reason="reauth_successful")
return self.async_show_form(step_id="reauth_confirm")
class OptionsFlowHandler(OptionsFlow):
"""Handle options.""" """Handle options."""
def __init__(self, config_entry: config_entries.ConfigEntry) -> None: def __init__(self, config_entry: ConfigEntry) -> None:
"""Initialize options flow.""" """Initialize options flow."""
self.config_entry = config_entry self.config_entry = config_entry
self.options = config_entry.options self.options = config_entry.options

View File

@ -39,6 +39,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.restore_state import RestoreEntity from homeassistant.helpers.restore_state import RestoreEntity
from homeassistant.helpers.trigger import PluggableAction from homeassistant.helpers.trigger import PluggableAction
from . import update_client_key
from .const import ( from .const import (
ATTR_PAYLOAD, ATTR_PAYLOAD,
ATTR_SOUND_OUTPUT, ATTR_SOUND_OUTPUT,
@ -73,18 +74,11 @@ SCAN_INTERVAL = timedelta(seconds=10)
async def async_setup_entry( async def async_setup_entry(
hass: HomeAssistant, hass: HomeAssistant, entry: ConfigEntry, async_add_entities: AddEntitiesCallback
config_entry: ConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the LG webOS Smart TV platform.""" """Set up the LG webOS Smart TV platform."""
unique_id = config_entry.unique_id client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][entry.entry_id]
assert unique_id async_add_entities([LgWebOSMediaPlayerEntity(entry, client)])
name = config_entry.title
sources = config_entry.options.get(CONF_SOURCES)
client = hass.data[DOMAIN][DATA_CONFIG_ENTRY][config_entry.entry_id]
async_add_entities([LgWebOSMediaPlayerEntity(client, name, sources, unique_id)])
_T = TypeVar("_T", bound="LgWebOSMediaPlayerEntity") _T = TypeVar("_T", bound="LgWebOSMediaPlayerEntity")
@ -123,19 +117,14 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
_attr_device_class = MediaPlayerDeviceClass.TV _attr_device_class = MediaPlayerDeviceClass.TV
def __init__( def __init__(self, entry: ConfigEntry, client: WebOsClient) -> None:
self,
client: WebOsClient,
name: str,
sources: list[str] | None,
unique_id: str,
) -> None:
"""Initialize the webos device.""" """Initialize the webos device."""
self._entry = entry
self._client = client self._client = client
self._attr_assumed_state = True self._attr_assumed_state = True
self._attr_name = name self._attr_name = entry.title
self._attr_unique_id = unique_id self._attr_unique_id = entry.unique_id
self._sources = sources self._sources = entry.options.get(CONF_SOURCES)
# Assume that the TV is not paused # Assume that the TV is not paused
self._paused = False self._paused = False
@ -326,7 +315,12 @@ class LgWebOSMediaPlayerEntity(RestoreEntity, MediaPlayerEntity):
return return
with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError): with suppress(*WEBOSTV_EXCEPTIONS, WebOsTvPairError):
try:
await self._client.connect() await self._client.connect()
except WebOsTvPairError:
self._entry.async_start_reauth(self.hass)
else:
update_client_key(self.hass, self._entry, self._client)
@property @property
def supported_features(self) -> MediaPlayerEntityFeature: def supported_features(self) -> MediaPlayerEntityFeature:

View File

@ -13,6 +13,10 @@
"pairing": { "pairing": {
"title": "webOS TV Pairing", "title": "webOS TV Pairing",
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)" "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
},
"reauth_confirm": {
"title": "webOS TV Pairing",
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)"
} }
}, },
"error": { "error": {
@ -21,7 +25,9 @@
"abort": { "abort": {
"error_pairing": "Connected to LG webOS TV but not paired", "error_pairing": "Connected to LG webOS TV but not paired",
"already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]", "already_in_progress": "[%key:common::config_flow::abort::already_in_progress%]",
"already_configured": "[%key:common::config_flow::abort::already_configured_device%]" "already_configured": "[%key:common::config_flow::abort::already_configured_device%]",
"reauth_successful": "[%key:common::config_flow::abort::reauth_successful%]",
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
} }
}, },
"options": { "options": {

View File

@ -3,7 +3,9 @@
"abort": { "abort": {
"already_configured": "Device is already configured", "already_configured": "Device is already configured",
"already_in_progress": "Configuration flow is already in progress", "already_in_progress": "Configuration flow is already in progress",
"error_pairing": "Connected to LG webOS TV but not paired" "error_pairing": "Connected to LG webOS TV but not paired",
"reauth_successful": "Re-authentication was successful",
"reauth_unsuccessful": "Re-authentication was unsuccessful, please turn on your TV and try again."
}, },
"error": { "error": {
"cannot_connect": "Failed to connect, please turn on your TV or check ip address" "cannot_connect": "Failed to connect, please turn on your TV or check ip address"
@ -14,6 +16,10 @@
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)", "description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
"title": "webOS TV Pairing" "title": "webOS TV Pairing"
}, },
"reauth_confirm": {
"description": "Click submit and accept the pairing request on your TV.\n\n![Image](/static/images/config_webos.png)",
"title": "webOS TV Pairing"
},
"user": { "user": {
"data": { "data": {
"host": "Host", "host": "Host",

View File

@ -9,11 +9,11 @@ from homeassistant import config_entries
from homeassistant.components import ssdp from homeassistant.components import ssdp
from homeassistant.components.webostv.const import CONF_SOURCES, DOMAIN, LIVE_TV_APP_ID from homeassistant.components.webostv.const import CONF_SOURCES, DOMAIN, LIVE_TV_APP_ID
from homeassistant.config_entries import SOURCE_SSDP from homeassistant.config_entries import SOURCE_SSDP
from homeassistant.const import CONF_HOST, CONF_NAME, CONF_SOURCE from homeassistant.const import CONF_CLIENT_SECRET, CONF_HOST, CONF_NAME, CONF_SOURCE
from homeassistant.data_entry_flow import FlowResultType from homeassistant.data_entry_flow import FlowResultType
from . import setup_webostv from . import setup_webostv
from .const import FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME from .const import CLIENT_KEY, FAKE_UUID, HOST, MOCK_APPS, MOCK_INPUTS, TV_NAME
MOCK_USER_CONFIG = { MOCK_USER_CONFIG = {
CONF_HOST: HOST, CONF_HOST: HOST,
@ -289,3 +289,64 @@ async def test_form_abort_uuid_configured(hass, client):
assert result["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "already_configured" assert result["reason"] == "already_configured"
assert entry.data[CONF_HOST] == "new_host" assert entry.data[CONF_HOST] == "new_host"
async def test_reauth_successful(hass, client, monkeypatch):
"""Test that the reauthorization is successful."""
entry = await setup_webostv(hass)
assert client
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
data=entry.data,
)
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
assert entry.data[CONF_CLIENT_SECRET] == CLIENT_KEY
monkeypatch.setattr(client, "client_key", "new_key")
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert entry.data[CONF_CLIENT_SECRET] == "new_key"
@pytest.mark.parametrize(
"side_effect,reason",
[
(WebOsTvPairError, "error_pairing"),
(ConnectionRefusedError, "reauth_unsuccessful"),
],
)
async def test_reauth_errors(hass, client, monkeypatch, side_effect, reason):
"""Test reauthorization errors."""
entry = await setup_webostv(hass)
assert client
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_REAUTH, "entry_id": entry.entry_id},
data=entry.data,
)
assert result["step_id"] == "reauth_confirm"
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "reauth_confirm"
monkeypatch.setattr(client, "connect", Mock(side_effect=side_effect))
result = await hass.config_entries.flow.async_configure(
result["flow_id"], user_input={}
)
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == reason

View File

@ -0,0 +1,39 @@
"""The tests for the LG webOS TV platform."""
from unittest.mock import Mock
from aiowebostv import WebOsTvPairError
from homeassistant.components.webostv.const import DOMAIN
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
from homeassistant.const import CONF_CLIENT_SECRET
from . import setup_webostv
async def test_reauth_setup_entry(hass, client, monkeypatch):
"""Test reauth flow triggered by setup entry."""
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
entry = await setup_webostv(hass)
assert entry.state == ConfigEntryState.SETUP_ERROR
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
flow = flows[0]
assert flow.get("step_id") == "reauth_confirm"
assert flow.get("handler") == DOMAIN
assert "context" in flow
assert flow["context"].get("source") == SOURCE_REAUTH
assert flow["context"].get("entry_id") == entry.entry_id
async def test_key_update_setup_entry(hass, client, monkeypatch):
"""Test key update from setup entry."""
monkeypatch.setattr(client, "client_key", "new_key")
entry = await setup_webostv(hass)
assert entry.state == ConfigEntryState.LOADED
assert entry.data[CONF_CLIENT_SECRET] == "new_key"

View File

@ -4,6 +4,7 @@ from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
from unittest.mock import Mock from unittest.mock import Mock
from aiowebostv import WebOsTvPairError
import pytest import pytest
from homeassistant.components import automation from homeassistant.components import automation
@ -37,6 +38,7 @@ from homeassistant.components.webostv.media_player import (
SUPPORT_WEBOSTV, SUPPORT_WEBOSTV,
SUPPORT_WEBOSTV_VOLUME, SUPPORT_WEBOSTV_VOLUME,
) )
from homeassistant.config_entries import SOURCE_REAUTH, ConfigEntryState
from homeassistant.const import ( from homeassistant.const import (
ATTR_COMMAND, ATTR_COMMAND,
ATTR_DEVICE_CLASS, ATTR_DEVICE_CLASS,
@ -763,3 +765,28 @@ async def test_get_image_https(
content = await resp.read() content = await resp.read()
assert content == b"https_image" assert content == b"https_image"
async def test_reauth_reconnect(hass, client, monkeypatch):
"""Test reauth flow triggered by reconnect."""
entry = await setup_webostv(hass)
monkeypatch.setattr(client, "is_connected", Mock(return_value=False))
monkeypatch.setattr(client, "connect", Mock(side_effect=WebOsTvPairError))
assert entry.state == ConfigEntryState.LOADED
async_fire_time_changed(hass, dt.utcnow() + timedelta(seconds=20))
await hass.async_block_till_done()
assert entry.state == ConfigEntryState.LOADED
flows = hass.config_entries.flow.async_progress()
assert len(flows) == 1
flow = flows[0]
assert flow.get("step_id") == "reauth_confirm"
assert flow.get("handler") == DOMAIN
assert "context" in flow
assert flow["context"].get("source") == SOURCE_REAUTH
assert flow["context"].get("entry_id") == entry.entry_id