mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Introduce new OAuth2 config flow helper (#27727)
* Refactor the Somfy auth implementation * Typing * Migrate Somfy to OAuth2 flow helper * Add tests * Add more tests * Fix tests * Fix type error * More tests * Remove side effect from constructor * implementation -> auth_implementation * Make get_implementation async * Minor cleanup + Allow picking implementations. * Add support for extra authorize data
This commit is contained in:
parent
6157be23dc
commit
b6c26cb363
@ -260,7 +260,7 @@ def _get_domains(hass: core.HomeAssistant, config: Dict[str, Any]) -> Set[str]:
|
||||
domains = set(key.split(" ")[0] for key in config.keys() if key != core.DOMAIN)
|
||||
|
||||
# Add config entry domains
|
||||
domains.update(hass.config_entries.async_domains()) # type: ignore
|
||||
domains.update(hass.config_entries.async_domains())
|
||||
|
||||
# Make sure the Hass.io component is loaded
|
||||
if "HASSIO" in os.environ:
|
||||
|
@ -4,21 +4,21 @@ Support for Somfy hubs.
|
||||
For more details about this component, please refer to the documentation at
|
||||
https://home-assistant.io/integrations/somfy/
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta
|
||||
from functools import partial
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.helpers import config_validation as cv, config_entry_oauth2_flow
|
||||
from homeassistant.components.somfy import config_flow
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_TOKEN
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.typing import HomeAssistantType
|
||||
from homeassistant.util import Throttle
|
||||
|
||||
from . import api
|
||||
|
||||
API = "api"
|
||||
|
||||
DEVICES = "devices"
|
||||
@ -52,19 +52,21 @@ SOMFY_COMPONENTS = ["cover"]
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Set up the Somfy component."""
|
||||
hass.data[DOMAIN] = {}
|
||||
|
||||
if DOMAIN not in config:
|
||||
return True
|
||||
|
||||
hass.data[DOMAIN] = {}
|
||||
|
||||
config_flow.register_flow_implementation(
|
||||
hass, config[DOMAIN][CONF_CLIENT_ID], config[DOMAIN][CONF_CLIENT_SECRET]
|
||||
)
|
||||
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_init(
|
||||
DOMAIN, context={"source": config_entries.SOURCE_IMPORT}
|
||||
)
|
||||
config_flow.SomfyFlowHandler.async_register_implementation(
|
||||
hass,
|
||||
config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass,
|
||||
DOMAIN,
|
||||
config[DOMAIN][CONF_CLIENT_ID],
|
||||
config[DOMAIN][CONF_CLIENT_SECRET],
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/auth",
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/token",
|
||||
),
|
||||
)
|
||||
|
||||
return True
|
||||
@ -72,25 +74,18 @@ async def async_setup(hass, config):
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
|
||||
"""Set up Somfy from a config entry."""
|
||||
|
||||
def token_saver(token):
|
||||
_LOGGER.debug("Saving updated token")
|
||||
entry.data[CONF_TOKEN] = token
|
||||
update_entry = partial(
|
||||
hass.config_entries.async_update_entry, data={**entry.data}
|
||||
# Backwards compat
|
||||
if "auth_implementation" not in entry.data:
|
||||
hass.config_entries.async_update_entry(
|
||||
entry, data={**entry.data, "auth_implementation": DOMAIN}
|
||||
)
|
||||
hass.add_job(update_entry, entry)
|
||||
|
||||
# Force token update.
|
||||
from pymfy.api.somfy_api import SomfyApi
|
||||
|
||||
hass.data[DOMAIN][API] = SomfyApi(
|
||||
entry.data["refresh_args"]["client_id"],
|
||||
entry.data["refresh_args"]["client_secret"],
|
||||
token=entry.data[CONF_TOKEN],
|
||||
token_updater=token_saver,
|
||||
implementation = await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||
hass, entry
|
||||
)
|
||||
|
||||
hass.data[DOMAIN][API] = api.ConfigEntrySomfyApi(hass, entry, implementation)
|
||||
|
||||
await update_all_devices(hass)
|
||||
|
||||
for component in SOMFY_COMPONENTS:
|
||||
@ -104,16 +99,22 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
|
||||
async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry):
|
||||
"""Unload a config entry."""
|
||||
hass.data[DOMAIN].pop(API, None)
|
||||
await asyncio.gather(
|
||||
*[
|
||||
hass.config_entries.async_forward_entry_unload(entry, component)
|
||||
for component in SOMFY_COMPONENTS
|
||||
]
|
||||
)
|
||||
return True
|
||||
|
||||
|
||||
class SomfyEntity(Entity):
|
||||
"""Representation of a generic Somfy device."""
|
||||
|
||||
def __init__(self, device, api):
|
||||
def __init__(self, device, somfy_api):
|
||||
"""Initialize the Somfy device."""
|
||||
self.device = device
|
||||
self.api = api
|
||||
self.api = somfy_api
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
|
55
homeassistant/components/somfy/api.py
Normal file
55
homeassistant/components/somfy/api.py
Normal file
@ -0,0 +1,55 @@
|
||||
"""API for Somfy bound to HASS OAuth."""
|
||||
from asyncio import run_coroutine_threadsafe
|
||||
from functools import partial
|
||||
|
||||
import requests
|
||||
from pymfy.api import somfy_api
|
||||
|
||||
from homeassistant import core, config_entries
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
|
||||
class ConfigEntrySomfyApi(somfy_api.AbstractSomfyApi):
|
||||
"""Provide a Somfy API tied into an OAuth2 based config entry."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: core.HomeAssistant,
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
implementation: config_entry_oauth2_flow.AbstractOAuth2Implementation,
|
||||
):
|
||||
"""Initialize the Config Entry Somfy API."""
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.session = config_entry_oauth2_flow.OAuth2Session(
|
||||
hass, config_entry, implementation
|
||||
)
|
||||
|
||||
def get(self, path):
|
||||
"""Fetch a URL from the Somfy API."""
|
||||
return run_coroutine_threadsafe(
|
||||
self._request("get", path), self.hass.loop
|
||||
).result()
|
||||
|
||||
def post(self, path, *, json):
|
||||
"""Post data to the Somfy API."""
|
||||
return run_coroutine_threadsafe(
|
||||
self._request("post", path, json=json), self.hass.loop
|
||||
).result()
|
||||
|
||||
async def _request(self, method, path, **kwargs):
|
||||
"""Make a request."""
|
||||
await self.session.async_ensure_token_valid()
|
||||
|
||||
return await self.hass.async_add_executor_job(
|
||||
partial(
|
||||
requests.request,
|
||||
method,
|
||||
f"{self.base_url}{path}",
|
||||
**kwargs,
|
||||
headers={
|
||||
**kwargs.get("headers", {}),
|
||||
"authorization": f"Bearer {self.config_entry.data['token']['access_token']}",
|
||||
},
|
||||
)
|
||||
)
|
@ -1,141 +1,28 @@
|
||||
"""Config flow for Somfy."""
|
||||
import asyncio
|
||||
import logging
|
||||
|
||||
import async_timeout
|
||||
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.core import callback
|
||||
from .const import CLIENT_ID, CLIENT_SECRET, DOMAIN
|
||||
|
||||
AUTH_CALLBACK_PATH = "/auth/somfy/callback"
|
||||
AUTH_CALLBACK_NAME = "auth:somfy:callback"
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@callback
|
||||
def register_flow_implementation(hass, client_id, client_secret):
|
||||
"""Register a flow implementation.
|
||||
@config_entries.HANDLERS.register(DOMAIN)
|
||||
class SomfyFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
|
||||
"""Config flow to handle Somfy OAuth2 authentication."""
|
||||
|
||||
client_id: Client id.
|
||||
client_secret: Client secret.
|
||||
"""
|
||||
hass.data[DOMAIN][CLIENT_ID] = client_id
|
||||
hass.data[DOMAIN][CLIENT_SECRET] = client_secret
|
||||
|
||||
|
||||
@config_entries.HANDLERS.register("somfy")
|
||||
class SomfyFlowHandler(config_entries.ConfigFlow):
|
||||
"""Handle a config flow."""
|
||||
|
||||
VERSION = 1
|
||||
DOMAIN = DOMAIN
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL
|
||||
|
||||
def __init__(self):
|
||||
"""Instantiate config flow."""
|
||||
self.code = None
|
||||
|
||||
async def async_step_import(self, user_input=None):
|
||||
"""Handle external yaml configuration."""
|
||||
if self.hass.config_entries.async_entries(DOMAIN):
|
||||
return self.async_abort(reason="already_setup")
|
||||
return await self.async_step_auth()
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
async def async_step_user(self, user_input=None):
|
||||
"""Handle a flow start."""
|
||||
if self.hass.config_entries.async_entries(DOMAIN):
|
||||
return self.async_abort(reason="already_setup")
|
||||
|
||||
if DOMAIN not in self.hass.data:
|
||||
return self.async_abort(reason="missing_configuration")
|
||||
|
||||
return await self.async_step_auth()
|
||||
|
||||
async def async_step_auth(self, user_input=None):
|
||||
"""Create an entry for auth."""
|
||||
# Flow has been triggered from Somfy website
|
||||
if user_input:
|
||||
return await self.async_step_code(user_input)
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
url, _ = await self._get_authorization_url()
|
||||
except asyncio.TimeoutError:
|
||||
return self.async_abort(reason="authorize_url_timeout")
|
||||
|
||||
return self.async_external_step(step_id="auth", url=url)
|
||||
|
||||
async def _get_authorization_url(self):
|
||||
"""Get Somfy authorization url."""
|
||||
from pymfy.api.somfy_api import SomfyApi
|
||||
|
||||
client_id = self.hass.data[DOMAIN][CLIENT_ID]
|
||||
client_secret = self.hass.data[DOMAIN][CLIENT_SECRET]
|
||||
redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}"
|
||||
api = SomfyApi(client_id, client_secret, redirect_uri)
|
||||
|
||||
self.hass.http.register_view(SomfyAuthCallbackView())
|
||||
# Thanks to the state, we can forward the flow id to Somfy that will
|
||||
# add it in the callback.
|
||||
return await self.hass.async_add_executor_job(
|
||||
api.get_authorization_url, self.flow_id
|
||||
)
|
||||
|
||||
async def async_step_code(self, code):
|
||||
"""Received code for authentication."""
|
||||
self.code = code
|
||||
return self.async_external_step_done(next_step_id="creation")
|
||||
|
||||
async def async_step_creation(self, user_input=None):
|
||||
"""Create Somfy api and entries."""
|
||||
client_id = self.hass.data[DOMAIN][CLIENT_ID]
|
||||
client_secret = self.hass.data[DOMAIN][CLIENT_SECRET]
|
||||
code = self.code
|
||||
from pymfy.api.somfy_api import SomfyApi
|
||||
|
||||
redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}"
|
||||
api = SomfyApi(client_id, client_secret, redirect_uri)
|
||||
token = await self.hass.async_add_executor_job(api.request_token, None, code)
|
||||
_LOGGER.info("Successfully authenticated Somfy")
|
||||
return self.async_create_entry(
|
||||
title="Somfy",
|
||||
data={
|
||||
"token": token,
|
||||
"refresh_args": {
|
||||
"client_id": client_id,
|
||||
"client_secret": client_secret,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class SomfyAuthCallbackView(HomeAssistantView):
|
||||
"""Somfy Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = AUTH_CALLBACK_PATH
|
||||
name = AUTH_CALLBACK_NAME
|
||||
|
||||
@staticmethod
|
||||
async def get(request):
|
||||
"""Receive authorization code."""
|
||||
from aiohttp import web_response
|
||||
|
||||
if "code" not in request.query or "state" not in request.query:
|
||||
return web_response.Response(
|
||||
text="Missing code or state parameter in " + request.url
|
||||
)
|
||||
|
||||
hass = request.app["hass"]
|
||||
hass.async_create_task(
|
||||
hass.config_entries.flow.async_configure(
|
||||
flow_id=request.query["state"], user_input=request.query["code"]
|
||||
)
|
||||
)
|
||||
|
||||
return web_response.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<script>window.close()</script>",
|
||||
)
|
||||
return await super().async_step_user(user_input)
|
||||
|
@ -1,5 +1,3 @@
|
||||
"""Define constants for the Somfy component."""
|
||||
|
||||
DOMAIN = "somfy"
|
||||
CLIENT_ID = "client_id"
|
||||
CLIENT_SECRET = "client_secret"
|
||||
|
@ -3,11 +3,7 @@
|
||||
"name": "Somfy Open API",
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/somfy",
|
||||
"dependencies": [],
|
||||
"codeowners": [
|
||||
"@tetienne"
|
||||
],
|
||||
"requirements": [
|
||||
"pymfy==0.5.2"
|
||||
]
|
||||
}
|
||||
"dependencies": ["http"],
|
||||
"codeowners": ["@tetienne"],
|
||||
"requirements": ["pymfy==0.6.0"]
|
||||
}
|
||||
|
@ -337,7 +337,7 @@ class ConfigEntry:
|
||||
return False
|
||||
if result:
|
||||
# pylint: disable=protected-access
|
||||
hass.config_entries._async_schedule_save() # type: ignore
|
||||
hass.config_entries._async_schedule_save()
|
||||
return result
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception(
|
||||
|
@ -77,7 +77,8 @@ from homeassistant.util.unit_system import ( # NOQA
|
||||
# Typing imports that create a circular dependency
|
||||
# pylint: disable=using-constant-test
|
||||
if TYPE_CHECKING:
|
||||
from homeassistant.config_entries import ConfigEntries # noqa
|
||||
from homeassistant.config_entries import ConfigEntries
|
||||
from homeassistant.components.http import HomeAssistantHTTP
|
||||
|
||||
# pylint: disable=invalid-name
|
||||
T = TypeVar("T")
|
||||
@ -162,6 +163,9 @@ class CoreState(enum.Enum):
|
||||
class HomeAssistant:
|
||||
"""Root object of the Home Assistant home automation."""
|
||||
|
||||
http: "HomeAssistantHTTP" = None # type: ignore
|
||||
config_entries: "ConfigEntries" = None # type: ignore
|
||||
|
||||
def __init__(self, loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None:
|
||||
"""Initialize new Home Assistant object."""
|
||||
self.loop: asyncio.events.AbstractEventLoop = (loop or asyncio.get_event_loop())
|
||||
@ -186,9 +190,6 @@ class HomeAssistant:
|
||||
self.data: dict = {}
|
||||
self.state = CoreState.not_running
|
||||
self.exit_code = 0
|
||||
self.config_entries: Optional[
|
||||
ConfigEntries # pylint: disable=used-before-assignment
|
||||
] = None
|
||||
# If not None, use to signal end-of-loop
|
||||
self._stopped: Optional[asyncio.Event] = None
|
||||
|
||||
|
@ -168,7 +168,7 @@ class FlowHandler:
|
||||
"""Handle the configuration flow of a component."""
|
||||
|
||||
# Set by flow manager
|
||||
flow_id: Optional[str] = None
|
||||
flow_id: str = None # type: ignore
|
||||
hass: Optional[HomeAssistant] = None
|
||||
handler: Optional[Hashable] = None
|
||||
cur_step: Optional[Dict[str, str]] = None
|
||||
|
420
homeassistant/helpers/config_entry_oauth2_flow.py
Normal file
420
homeassistant/helpers/config_entry_oauth2_flow.py
Normal file
@ -0,0 +1,420 @@
|
||||
"""Config Flow using OAuth2.
|
||||
|
||||
This module exists of the following parts:
|
||||
- OAuth2 config flow which supports multiple OAuth2 implementations
|
||||
- OAuth2 implementation that works with local provided client ID/secret
|
||||
|
||||
"""
|
||||
import asyncio
|
||||
from abc import ABCMeta, ABC, abstractmethod
|
||||
import logging
|
||||
from typing import Optional, Any, Dict, cast
|
||||
import time
|
||||
|
||||
import async_timeout
|
||||
from aiohttp import web, client
|
||||
import jwt
|
||||
import voluptuous as vol
|
||||
from yarl import URL
|
||||
|
||||
from homeassistant.auth.util import generate_secret
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant import config_entries
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
|
||||
from .aiohttp_client import async_get_clientsession
|
||||
|
||||
|
||||
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||
|
||||
|
||||
class AbstractOAuth2Implementation(ABC):
|
||||
"""Base class to abstract OAuth2 authentication."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def domain(self) -> str:
|
||||
"""Domain that is providing the implementation."""
|
||||
|
||||
@abstractmethod
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize.
|
||||
|
||||
This step is called when a config flow is initialized. It should redirect the
|
||||
user to the vendor website where they can authorize Home Assistant.
|
||||
|
||||
The implementation is responsible to get notified when the user is authorized
|
||||
and pass this to the specified config flow. Do as little work as possible once
|
||||
notified. You can do the work inside async_resolve_external_data. This will
|
||||
give the best UX.
|
||||
|
||||
Pass external data in with:
|
||||
|
||||
```python
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=flow_id, user_input=external_data
|
||||
)
|
||||
```
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||
"""Resolve external data to tokens.
|
||||
|
||||
Turn the data that the implementation passed to the config flow as external
|
||||
step data into tokens. These tokens will be stored as 'token' in the
|
||||
config entry data.
|
||||
"""
|
||||
|
||||
async def async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token and update expires info."""
|
||||
new_token = await self._async_refresh_token(token)
|
||||
new_token["expires_at"] = time.time() + new_token["expires_in"]
|
||||
return new_token
|
||||
|
||||
@abstractmethod
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token."""
|
||||
|
||||
|
||||
class LocalOAuth2Implementation(AbstractOAuth2Implementation):
|
||||
"""Local OAuth2 implementation."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
client_id: str,
|
||||
client_secret: str,
|
||||
authorize_url: str,
|
||||
token_url: str,
|
||||
):
|
||||
"""Initialize local auth implementation."""
|
||||
self.hass = hass
|
||||
self._domain = domain
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.authorize_url = authorize_url
|
||||
self.token_url = token_url
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
return "Configuration.yaml"
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""Domain providing the implementation."""
|
||||
return self._domain
|
||||
|
||||
@property
|
||||
def redirect_uri(self) -> str:
|
||||
"""Return the redirect uri."""
|
||||
return f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" # type: ignore
|
||||
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize."""
|
||||
return str(
|
||||
URL(self.authorize_url).with_query(
|
||||
{
|
||||
"response_type": "code",
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"state": _encode_jwt(self.hass, {"flow_id": flow_id}),
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||
"""Resolve the authorization code to tokens."""
|
||||
return await self._token_request(
|
||||
{
|
||||
"grant_type": "authorization_code",
|
||||
"code": external_data,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
}
|
||||
)
|
||||
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh tokens."""
|
||||
new_token = await self._token_request(
|
||||
{
|
||||
"grant_type": "refresh_token",
|
||||
"client_id": self.client_id,
|
||||
"refresh_token": token["refresh_token"],
|
||||
}
|
||||
)
|
||||
return {**token, **new_token}
|
||||
|
||||
async def _token_request(self, data: dict) -> dict:
|
||||
"""Make a token request."""
|
||||
session = async_get_clientsession(self.hass)
|
||||
|
||||
data["client_id"] = self.client_id
|
||||
|
||||
if self.client_secret is not None:
|
||||
data["client_secret"] = self.client_secret
|
||||
|
||||
resp = await session.post(self.token_url, data=data)
|
||||
resp.raise_for_status()
|
||||
return cast(dict, await resp.json())
|
||||
|
||||
|
||||
class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
|
||||
"""Handle a config flow."""
|
||||
|
||||
DOMAIN = ""
|
||||
|
||||
VERSION = 1
|
||||
CONNECTION_CLASS = config_entries.CONN_CLASS_UNKNOWN
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Instantiate config flow."""
|
||||
if self.DOMAIN == "":
|
||||
raise TypeError(
|
||||
f"Can't instantiate class {self.__class__.__name__} without DOMAIN being set"
|
||||
)
|
||||
|
||||
self.external_data: Any = None
|
||||
self.flow_impl: AbstractOAuth2Implementation = None # type: ignore
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
|
||||
@property
|
||||
def extra_authorize_data(self) -> dict:
|
||||
"""Extra data that needs to be appended to the authorize url."""
|
||||
return {}
|
||||
|
||||
async def async_step_pick_implementation(self, user_input: dict = None) -> dict:
|
||||
"""Handle a flow start."""
|
||||
assert self.hass
|
||||
implementations = await async_get_implementations(self.hass, self.DOMAIN)
|
||||
|
||||
if user_input is not None:
|
||||
self.flow_impl = implementations[user_input["implementation"]]
|
||||
return await self.async_step_auth()
|
||||
|
||||
if not implementations:
|
||||
return self.async_abort(reason="missing_configuration")
|
||||
|
||||
if len(implementations) == 1:
|
||||
# Pick first implementation as we have only one.
|
||||
self.flow_impl = list(implementations.values())[0]
|
||||
return await self.async_step_auth()
|
||||
|
||||
return self.async_show_form(
|
||||
step_id="pick_implementation",
|
||||
data_schema=vol.Schema(
|
||||
{
|
||||
vol.Required(
|
||||
"implementation", default=list(implementations.keys())[0]
|
||||
): vol.In({key: impl.name for key, impl in implementations.items()})
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
async def async_step_auth(self, user_input: dict = None) -> dict:
|
||||
"""Create an entry for auth."""
|
||||
# Flow has been triggered by external data
|
||||
if user_input:
|
||||
self.external_data = user_input
|
||||
return self.async_external_step_done(next_step_id="creation")
|
||||
|
||||
try:
|
||||
with async_timeout.timeout(10):
|
||||
url = await self.flow_impl.async_generate_authorize_url(self.flow_id)
|
||||
except asyncio.TimeoutError:
|
||||
return self.async_abort(reason="authorize_url_timeout")
|
||||
|
||||
url = str(URL(url).update_query(self.extra_authorize_data))
|
||||
|
||||
return self.async_external_step(step_id="auth", url=url)
|
||||
|
||||
async def async_step_creation(self, user_input: dict = None) -> dict:
|
||||
"""Create config entry from external data."""
|
||||
token = await self.flow_impl.async_resolve_external_data(self.external_data)
|
||||
token["expires_at"] = time.time() + token["expires_in"]
|
||||
|
||||
self.logger.info("Successfully authenticated")
|
||||
|
||||
return await self.async_oauth_create_entry(
|
||||
{"auth_implementation": self.flow_impl.domain, "token": token}
|
||||
)
|
||||
|
||||
async def async_oauth_create_entry(self, data: dict) -> dict:
|
||||
"""Create an entry for the flow.
|
||||
|
||||
Ok to override if you want to fetch extra info or even add another step.
|
||||
"""
|
||||
return self.async_create_entry(title=self.flow_impl.name, data=data)
|
||||
|
||||
async_step_user = async_step_pick_implementation
|
||||
async_step_ssdp = async_step_pick_implementation
|
||||
async_step_zeroconf = async_step_pick_implementation
|
||||
async_step_homekit = async_step_pick_implementation
|
||||
|
||||
@classmethod
|
||||
def async_register_implementation(
|
||||
cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation
|
||||
) -> None:
|
||||
"""Register a local implementation."""
|
||||
async_register_implementation(hass, cls.DOMAIN, local_impl)
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_implementation(
|
||||
hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation
|
||||
) -> None:
|
||||
"""Register an OAuth2 flow implementation for an integration."""
|
||||
if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
|
||||
DATA_VIEW_REGISTERED, False
|
||||
):
|
||||
hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
|
||||
hass.data[DATA_VIEW_REGISTERED] = True
|
||||
|
||||
implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
|
||||
implementations.setdefault(domain, {})[implementation.domain] = implementation
|
||||
|
||||
|
||||
async def async_get_implementations(
|
||||
hass: HomeAssistant, domain: str
|
||||
) -> Dict[str, AbstractOAuth2Implementation]:
|
||||
"""Return OAuth2 implementations for specified domain."""
|
||||
return cast(
|
||||
Dict[str, AbstractOAuth2Implementation],
|
||||
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
|
||||
)
|
||||
|
||||
|
||||
async def async_get_config_entry_implementation(
|
||||
hass: HomeAssistant, config_entry: config_entries.ConfigEntry
|
||||
) -> AbstractOAuth2Implementation:
|
||||
"""Return the implementation for this config entry."""
|
||||
implementations = await async_get_implementations(hass, config_entry.domain)
|
||||
implementation = implementations.get(config_entry.data["auth_implementation"])
|
||||
|
||||
if implementation is None:
|
||||
raise ValueError("Implementation not available")
|
||||
|
||||
return implementation
|
||||
|
||||
|
||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||
"""OAuth2 Authorization Callback View."""
|
||||
|
||||
requires_auth = False
|
||||
url = AUTH_CALLBACK_PATH
|
||||
name = "auth:external:callback"
|
||||
|
||||
async def get(self, request: web.Request) -> web.Response:
|
||||
"""Receive authorization code."""
|
||||
if "code" not in request.query or "state" not in request.query:
|
||||
return web.Response(
|
||||
text=f"Missing code or state parameter in {request.url}"
|
||||
)
|
||||
|
||||
hass = request.app["hass"]
|
||||
|
||||
state = _decode_jwt(hass, request.query["state"])
|
||||
|
||||
if state is None:
|
||||
return web.Response(text=f"Invalid state")
|
||||
|
||||
await hass.config_entries.flow.async_configure(
|
||||
flow_id=state["flow_id"], user_input=request.query["code"]
|
||||
)
|
||||
|
||||
return web.Response(
|
||||
headers={"content-type": "text/html"},
|
||||
text="<script>window.close()</script>",
|
||||
)
|
||||
|
||||
|
||||
class OAuth2Session:
|
||||
"""Session to make requests authenticated with OAuth2."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hass: HomeAssistant,
|
||||
config_entry: config_entries.ConfigEntry,
|
||||
implementation: AbstractOAuth2Implementation,
|
||||
):
|
||||
"""Initialize an OAuth2 session."""
|
||||
self.hass = hass
|
||||
self.config_entry = config_entry
|
||||
self.implementation = implementation
|
||||
|
||||
async def async_ensure_token_valid(self) -> None:
|
||||
"""Ensure that the current token is valid."""
|
||||
token = self.config_entry.data["token"]
|
||||
|
||||
if token["expires_at"] > time.time():
|
||||
return
|
||||
|
||||
new_token = await self.implementation.async_refresh_token(token)
|
||||
|
||||
self.hass.config_entries.async_update_entry( # type: ignore
|
||||
self.config_entry, data={**self.config_entry.data, "token": new_token}
|
||||
)
|
||||
|
||||
async def async_request(
|
||||
self, method: str, url: str, **kwargs: Any
|
||||
) -> client.ClientResponse:
|
||||
"""Make a request."""
|
||||
await self.async_ensure_token_valid()
|
||||
return await async_oauth2_request(
|
||||
self.hass, self.config_entry.data["token"], method, url, **kwargs
|
||||
)
|
||||
|
||||
|
||||
async def async_oauth2_request(
|
||||
hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any
|
||||
) -> client.ClientResponse:
|
||||
"""Make an OAuth2 authenticated request.
|
||||
|
||||
This method will not refresh tokens. Use OAuth2 session for that.
|
||||
"""
|
||||
session = async_get_clientsession(hass)
|
||||
|
||||
return await session.request(
|
||||
method,
|
||||
url,
|
||||
**kwargs,
|
||||
headers={
|
||||
**kwargs.get("headers", {}),
|
||||
"authorization": f"Bearer {token['access_token']}",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@callback
|
||||
def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
|
||||
"""JWT encode data."""
|
||||
secret = hass.data.get(DATA_JWT_SECRET)
|
||||
|
||||
if secret is None:
|
||||
secret = hass.data[DATA_JWT_SECRET] = generate_secret()
|
||||
|
||||
return jwt.encode(data, secret, algorithm="HS256").decode()
|
||||
|
||||
|
||||
@callback
|
||||
def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]:
|
||||
"""JWT encode data."""
|
||||
secret = cast(str, hass.data.get(DATA_JWT_SECRET))
|
||||
|
||||
try:
|
||||
return jwt.decode(encoded, secret, algorithms=["HS256"])
|
||||
except jwt.InvalidTokenError:
|
||||
return None
|
@ -1304,7 +1304,7 @@ pymailgunner==1.4
|
||||
pymediaroom==0.6.4
|
||||
|
||||
# homeassistant.components.somfy
|
||||
pymfy==0.5.2
|
||||
pymfy==0.6.0
|
||||
|
||||
# homeassistant.components.xiaomi_tv
|
||||
pymitv==1.4.3
|
||||
|
@ -447,7 +447,7 @@ pylitejet==0.1
|
||||
pymailgunner==1.4
|
||||
|
||||
# homeassistant.components.somfy
|
||||
pymfy==0.5.2
|
||||
pymfy==0.6.0
|
||||
|
||||
# homeassistant.components.mochad
|
||||
pymochad==0.2.0
|
||||
|
@ -1015,14 +1015,23 @@ def mock_entity_platform(hass, platform_path, module):
|
||||
hue.light.
|
||||
"""
|
||||
domain, platform_name = platform_path.split(".")
|
||||
integration_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
|
||||
mock_platform(hass, f"{platform_name}.{domain}", module)
|
||||
|
||||
|
||||
def mock_platform(hass, platform_path, module=None):
|
||||
"""Mock a platform.
|
||||
|
||||
platform_path is in form hue.config_flow.
|
||||
"""
|
||||
domain, platform_name = platform_path.split(".")
|
||||
integration_cache = hass.data.setdefault(loader.DATA_INTEGRATIONS, {})
|
||||
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
|
||||
|
||||
if platform_name not in integration_cache:
|
||||
mock_integration(hass, MockModule(platform_name))
|
||||
if domain not in integration_cache:
|
||||
mock_integration(hass, MockModule(domain))
|
||||
|
||||
_LOGGER.info("Adding mock integration platform: %s", platform_path)
|
||||
module_cache["{}.{}".format(platform_name, domain)] = module
|
||||
module_cache[platform_path] = module or Mock()
|
||||
|
||||
|
||||
def async_capture_events(hass, event_name):
|
||||
|
@ -1,19 +1,35 @@
|
||||
"""Tests for the Somfy config flow."""
|
||||
import asyncio
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
from pymfy.api.somfy_api import SomfyApi
|
||||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow
|
||||
from homeassistant import data_entry_flow, setup, config_entries
|
||||
from homeassistant.components.somfy import config_flow, DOMAIN
|
||||
from homeassistant.components.somfy.config_flow import register_flow_implementation
|
||||
from tests.common import MockConfigEntry, mock_coro
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from tests.common import MockConfigEntry
|
||||
|
||||
CLIENT_SECRET_VALUE = "5678"
|
||||
|
||||
CLIENT_ID_VALUE = "1234"
|
||||
|
||||
AUTH_URL = "http://somfy.com"
|
||||
|
||||
@pytest.fixture()
|
||||
async def mock_impl(hass):
|
||||
"""Mock implementation."""
|
||||
await setup.async_setup_component(hass, "http", {})
|
||||
|
||||
impl = config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass,
|
||||
DOMAIN,
|
||||
CLIENT_ID_VALUE,
|
||||
CLIENT_SECRET_VALUE,
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/auth",
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/token",
|
||||
)
|
||||
config_flow.SomfyFlowHandler.async_register_implementation(hass, impl)
|
||||
return impl
|
||||
|
||||
|
||||
async def test_abort_if_no_configuration(hass):
|
||||
@ -30,47 +46,84 @@ async def test_abort_if_existing_entry(hass):
|
||||
flow = config_flow.SomfyFlowHandler()
|
||||
flow.hass = hass
|
||||
MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
|
||||
result = await flow.async_step_import()
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_setup"
|
||||
|
||||
result = await flow.async_step_user()
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "already_setup"
|
||||
|
||||
|
||||
async def test_full_flow(hass):
|
||||
"""Check classic use case."""
|
||||
hass.data[DOMAIN] = {}
|
||||
register_flow_implementation(hass, CLIENT_ID_VALUE, CLIENT_SECRET_VALUE)
|
||||
flow = config_flow.SomfyFlowHandler()
|
||||
flow.hass = hass
|
||||
hass.config.api = Mock(base_url="https://example.com")
|
||||
flow._get_authorization_url = Mock(return_value=mock_coro((AUTH_URL, "state")))
|
||||
result = await flow.async_step_import()
|
||||
async def test_full_flow(hass, aiohttp_client, aioclient_mock):
|
||||
"""Check full flow."""
|
||||
assert await setup.async_setup_component(
|
||||
hass,
|
||||
"somfy",
|
||||
{
|
||||
"somfy": {
|
||||
"client_id": CLIENT_ID_VALUE,
|
||||
"client_secret": CLIENT_SECRET_VALUE,
|
||||
},
|
||||
"http": {"base_url": "https://example.com"},
|
||||
},
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
"somfy", context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == AUTH_URL
|
||||
result = await flow.async_step_auth("my_super_code")
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP_DONE
|
||||
assert result["step_id"] == "creation"
|
||||
assert flow.code == "my_super_code"
|
||||
with patch.object(
|
||||
SomfyApi, "request_token", return_value={"access_token": "super_token"}
|
||||
):
|
||||
result = await flow.async_step_creation()
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result["data"]["refresh_args"] == {
|
||||
"client_id": CLIENT_ID_VALUE,
|
||||
"client_secret": CLIENT_SECRET_VALUE,
|
||||
assert result["url"] == (
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/auth"
|
||||
f"?response_type=code&client_id={CLIENT_ID_VALUE}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&state={state}"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
aioclient_mock.post(
|
||||
"https://accounts.somfy.com/oauth/oauth/v2/token",
|
||||
json={
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
with patch("homeassistant.components.somfy.api.ConfigEntrySomfyApi"):
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["data"]["auth_implementation"] == "somfy"
|
||||
|
||||
result["data"]["token"].pop("expires_at")
|
||||
assert result["data"]["token"] == {
|
||||
"refresh_token": "mock-refresh-token",
|
||||
"access_token": "mock-access-token",
|
||||
"type": "Bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
assert result["title"] == "Somfy"
|
||||
assert result["data"]["token"] == {"access_token": "super_token"}
|
||||
|
||||
assert "somfy" in hass.config.components
|
||||
entry = hass.config_entries.async_entries("somfy")[0]
|
||||
assert entry.state == config_entries.ENTRY_STATE_LOADED
|
||||
|
||||
assert await hass.config_entries.async_unload(entry.entry_id)
|
||||
assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
|
||||
|
||||
|
||||
async def test_abort_if_authorization_timeout(hass):
|
||||
async def test_abort_if_authorization_timeout(hass, mock_impl):
|
||||
"""Check Somfy authorization timeout."""
|
||||
flow = config_flow.SomfyFlowHandler()
|
||||
flow.hass = hass
|
||||
flow._get_authorization_url = Mock(side_effect=asyncio.TimeoutError)
|
||||
result = await flow.async_step_auth()
|
||||
|
||||
with patch.object(
|
||||
mock_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError
|
||||
):
|
||||
result = await flow.async_step_user()
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "authorize_url_timeout"
|
||||
|
266
tests/helpers/test_config_entry_oauth2_flow.py
Normal file
266
tests/helpers/test_config_entry_oauth2_flow.py
Normal file
@ -0,0 +1,266 @@
|
||||
"""Tests for the Somfy config flow."""
|
||||
import asyncio
|
||||
import logging
|
||||
from unittest.mock import patch
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant import data_entry_flow, setup, config_entries
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
|
||||
from tests.common import mock_platform, MockConfigEntry
|
||||
|
||||
TEST_DOMAIN = "oauth2_test"
|
||||
CLIENT_SECRET = "5678"
|
||||
CLIENT_ID = "1234"
|
||||
REFRESH_TOKEN = "mock-refresh-token"
|
||||
ACCESS_TOKEN_1 = "mock-access-token-1"
|
||||
ACCESS_TOKEN_2 = "mock-access-token-2"
|
||||
AUTHORIZE_URL = "https://example.como/auth/authorize"
|
||||
TOKEN_URL = "https://example.como/auth/token"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def local_impl(hass):
|
||||
"""Local implementation."""
|
||||
assert await setup.async_setup_component(hass, "http", {})
|
||||
return config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass, TEST_DOMAIN, CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flow_handler(hass):
|
||||
"""Return a registered config flow."""
|
||||
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
|
||||
class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
|
||||
"""Test flow handler."""
|
||||
|
||||
DOMAIN = TEST_DOMAIN
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
@property
|
||||
def extra_authorize_data(self) -> dict:
|
||||
"""Extra data that needs to be appended to the authorize url."""
|
||||
return {"scope": "read write"}
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
|
||||
yield TestFlowHandler
|
||||
|
||||
|
||||
class MockOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
|
||||
"""Mock implementation for testing."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
return "Mock"
|
||||
|
||||
@property
|
||||
def domain(self) -> str:
|
||||
"""Domain that is providing the implementation."""
|
||||
return "test"
|
||||
|
||||
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||
"""Generate a url for the user to authorize."""
|
||||
return "http://example.com/auth"
|
||||
|
||||
async def async_resolve_external_data(self, external_data) -> dict:
|
||||
"""Resolve external data to tokens."""
|
||||
return external_data
|
||||
|
||||
async def _async_refresh_token(self, token: dict) -> dict:
|
||||
"""Refresh a token."""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def test_inherit_enforces_domain_set():
|
||||
"""Test we enforce setting DOMAIN."""
|
||||
|
||||
class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
|
||||
"""Test flow handler."""
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
|
||||
with pytest.raises(TypeError):
|
||||
TestFlowHandler()
|
||||
|
||||
|
||||
async def test_abort_if_no_implementation(hass, flow_handler):
|
||||
"""Check flow abort when no implementations."""
|
||||
flow = flow_handler()
|
||||
flow.hass = hass
|
||||
result = await flow.async_step_user()
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "missing_configuration"
|
||||
|
||||
|
||||
async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
|
||||
"""Check timeout generating authorization url."""
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
|
||||
flow = flow_handler()
|
||||
flow.hass = hass
|
||||
|
||||
with patch.object(
|
||||
local_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError
|
||||
):
|
||||
result = await flow.async_step_user()
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result["reason"] == "authorize_url_timeout"
|
||||
|
||||
|
||||
async def test_full_flow(
|
||||
hass, flow_handler, local_impl, aiohttp_client, aioclient_mock
|
||||
):
|
||||
"""Check full flow."""
|
||||
hass.config.api.base_url = "https://example.com"
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
config_entry_oauth2_flow.async_register_implementation(
|
||||
hass, TEST_DOMAIN, MockOAuth2Implementation()
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result["step_id"] == "pick_implementation"
|
||||
|
||||
# Pick implementation
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"], user_input={"implementation": TEST_DOMAIN}
|
||||
)
|
||||
|
||||
state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
|
||||
|
||||
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
assert result["url"] == (
|
||||
f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&state={state}&scope=read+write"
|
||||
)
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"type": "bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
|
||||
assert result["data"]["auth_implementation"] == TEST_DOMAIN
|
||||
|
||||
result["data"]["token"].pop("expires_at")
|
||||
assert result["data"]["token"] == {
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"type": "bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
|
||||
entry = hass.config_entries.async_entries(TEST_DOMAIN)[0]
|
||||
|
||||
assert (
|
||||
await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||
hass, entry
|
||||
)
|
||||
is local_impl
|
||||
)
|
||||
|
||||
|
||||
async def test_local_refresh_token(hass, local_impl, aioclient_mock):
|
||||
"""Test we can refresh token."""
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100}
|
||||
)
|
||||
|
||||
new_tokens = await local_impl.async_refresh_token(
|
||||
{
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"type": "bearer",
|
||||
"expires_in": 60,
|
||||
}
|
||||
)
|
||||
new_tokens.pop("expires_at")
|
||||
|
||||
assert new_tokens == {
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_2,
|
||||
"type": "bearer",
|
||||
"expires_in": 100,
|
||||
}
|
||||
|
||||
assert len(aioclient_mock.mock_calls) == 1
|
||||
assert aioclient_mock.mock_calls[0][2] == {
|
||||
"client_id": CLIENT_ID,
|
||||
"client_secret": CLIENT_SECRET,
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
}
|
||||
|
||||
|
||||
async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
|
||||
"""Test the OAuth2 session helper."""
|
||||
flow_handler.async_register_implementation(hass, local_impl)
|
||||
|
||||
aioclient_mock.post(
|
||||
TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100}
|
||||
)
|
||||
|
||||
aioclient_mock.post("https://example.com", status=201)
|
||||
|
||||
config_entry = MockConfigEntry(
|
||||
domain=TEST_DOMAIN,
|
||||
data={
|
||||
"auth_implementation": TEST_DOMAIN,
|
||||
"token": {
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN_1,
|
||||
"expires_in": 10,
|
||||
"expires_at": 0, # Forces a refresh,
|
||||
"token_type": "bearer",
|
||||
"random_other_data": "should_stay",
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
|
||||
resp = await session.async_request("post", "https://example.com")
|
||||
assert resp.status == 201
|
||||
|
||||
# Refresh token, make request
|
||||
assert len(aioclient_mock.mock_calls) == 2
|
||||
|
||||
assert (
|
||||
aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}"
|
||||
)
|
||||
|
||||
assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
|
||||
assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2
|
||||
assert config_entry.data["token"]["expires_in"] == 100
|
||||
assert config_entry.data["token"]["random_other_data"] == "should_stay"
|
||||
assert round(config_entry.data["token"]["expires_at"] - now) == 100
|
Loading…
x
Reference in New Issue
Block a user