diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py
index e399205ec70..6118f4f2bd7 100644
--- a/homeassistant/bootstrap.py
+++ b/homeassistant/bootstrap.py
@@ -260,7 +260,7 @@ def _get_domains(hass: core.HomeAssistant, config: Dict[str, Any]) -> Set[str]:
domains = set(key.split(" ")[0] for key in config.keys() if key != core.DOMAIN)
# Add config entry domains
- domains.update(hass.config_entries.async_domains()) # type: ignore
+ domains.update(hass.config_entries.async_domains())
# Make sure the Hass.io component is loaded
if "HASSIO" in os.environ:
diff --git a/homeassistant/components/somfy/__init__.py b/homeassistant/components/somfy/__init__.py
index 2c7c71d7a69..cd5960bf6b1 100644
--- a/homeassistant/components/somfy/__init__.py
+++ b/homeassistant/components/somfy/__init__.py
@@ -4,21 +4,21 @@ Support for Somfy hubs.
For more details about this component, please refer to the documentation at
https://home-assistant.io/integrations/somfy/
"""
+import asyncio
import logging
from datetime import timedelta
-from functools import partial
import voluptuous as vol
-import homeassistant.helpers.config_validation as cv
-from homeassistant import config_entries
+from homeassistant.helpers import config_validation as cv, config_entry_oauth2_flow
from homeassistant.components.somfy import config_flow
from homeassistant.config_entries import ConfigEntry
-from homeassistant.const import CONF_TOKEN
from homeassistant.helpers.entity import Entity
from homeassistant.helpers.typing import HomeAssistantType
from homeassistant.util import Throttle
+from . import api
+
API = "api"
DEVICES = "devices"
@@ -52,19 +52,21 @@ SOMFY_COMPONENTS = ["cover"]
async def async_setup(hass, config):
"""Set up the Somfy component."""
+ hass.data[DOMAIN] = {}
+
if DOMAIN not in config:
return True
- hass.data[DOMAIN] = {}
-
- config_flow.register_flow_implementation(
- hass, config[DOMAIN][CONF_CLIENT_ID], config[DOMAIN][CONF_CLIENT_SECRET]
- )
-
- hass.async_create_task(
- hass.config_entries.flow.async_init(
- DOMAIN, context={"source": config_entries.SOURCE_IMPORT}
- )
+ config_flow.SomfyFlowHandler.async_register_implementation(
+ hass,
+ config_entry_oauth2_flow.LocalOAuth2Implementation(
+ hass,
+ DOMAIN,
+ config[DOMAIN][CONF_CLIENT_ID],
+ config[DOMAIN][CONF_CLIENT_SECRET],
+ "https://accounts.somfy.com/oauth/oauth/v2/auth",
+ "https://accounts.somfy.com/oauth/oauth/v2/token",
+ ),
)
return True
@@ -72,25 +74,18 @@ async def async_setup(hass, config):
async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
"""Set up Somfy from a config entry."""
-
- def token_saver(token):
- _LOGGER.debug("Saving updated token")
- entry.data[CONF_TOKEN] = token
- update_entry = partial(
- hass.config_entries.async_update_entry, data={**entry.data}
+ # Backwards compat
+ if "auth_implementation" not in entry.data:
+ hass.config_entries.async_update_entry(
+ entry, data={**entry.data, "auth_implementation": DOMAIN}
)
- hass.add_job(update_entry, entry)
- # Force token update.
- from pymfy.api.somfy_api import SomfyApi
-
- hass.data[DOMAIN][API] = SomfyApi(
- entry.data["refresh_args"]["client_id"],
- entry.data["refresh_args"]["client_secret"],
- token=entry.data[CONF_TOKEN],
- token_updater=token_saver,
+ implementation = await config_entry_oauth2_flow.async_get_config_entry_implementation(
+ hass, entry
)
+ hass.data[DOMAIN][API] = api.ConfigEntrySomfyApi(hass, entry, implementation)
+
await update_all_devices(hass)
for component in SOMFY_COMPONENTS:
@@ -104,16 +99,22 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry):
async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry):
"""Unload a config entry."""
hass.data[DOMAIN].pop(API, None)
+ await asyncio.gather(
+ *[
+ hass.config_entries.async_forward_entry_unload(entry, component)
+ for component in SOMFY_COMPONENTS
+ ]
+ )
return True
class SomfyEntity(Entity):
"""Representation of a generic Somfy device."""
- def __init__(self, device, api):
+ def __init__(self, device, somfy_api):
"""Initialize the Somfy device."""
self.device = device
- self.api = api
+ self.api = somfy_api
@property
def unique_id(self):
diff --git a/homeassistant/components/somfy/api.py b/homeassistant/components/somfy/api.py
new file mode 100644
index 00000000000..3e7bcf9deb4
--- /dev/null
+++ b/homeassistant/components/somfy/api.py
@@ -0,0 +1,55 @@
+"""API for Somfy bound to HASS OAuth."""
+from asyncio import run_coroutine_threadsafe
+from functools import partial
+
+import requests
+from pymfy.api import somfy_api
+
+from homeassistant import core, config_entries
+from homeassistant.helpers import config_entry_oauth2_flow
+
+
+class ConfigEntrySomfyApi(somfy_api.AbstractSomfyApi):
+ """Provide a Somfy API tied into an OAuth2 based config entry."""
+
+ def __init__(
+ self,
+ hass: core.HomeAssistant,
+ config_entry: config_entries.ConfigEntry,
+ implementation: config_entry_oauth2_flow.AbstractOAuth2Implementation,
+ ):
+ """Initialize the Config Entry Somfy API."""
+ self.hass = hass
+ self.config_entry = config_entry
+ self.session = config_entry_oauth2_flow.OAuth2Session(
+ hass, config_entry, implementation
+ )
+
+ def get(self, path):
+ """Fetch a URL from the Somfy API."""
+ return run_coroutine_threadsafe(
+ self._request("get", path), self.hass.loop
+ ).result()
+
+ def post(self, path, *, json):
+ """Post data to the Somfy API."""
+ return run_coroutine_threadsafe(
+ self._request("post", path, json=json), self.hass.loop
+ ).result()
+
+ async def _request(self, method, path, **kwargs):
+ """Make a request."""
+ await self.session.async_ensure_token_valid()
+
+ return await self.hass.async_add_executor_job(
+ partial(
+ requests.request,
+ method,
+ f"{self.base_url}{path}",
+ **kwargs,
+ headers={
+ **kwargs.get("headers", {}),
+ "authorization": f"Bearer {self.config_entry.data['token']['access_token']}",
+ },
+ )
+ )
diff --git a/homeassistant/components/somfy/config_flow.py b/homeassistant/components/somfy/config_flow.py
index 9f3c58c8ffb..cb180d4e247 100644
--- a/homeassistant/components/somfy/config_flow.py
+++ b/homeassistant/components/somfy/config_flow.py
@@ -1,141 +1,28 @@
"""Config flow for Somfy."""
-import asyncio
import logging
-import async_timeout
-
from homeassistant import config_entries
-from homeassistant.components.http import HomeAssistantView
-from homeassistant.core import callback
-from .const import CLIENT_ID, CLIENT_SECRET, DOMAIN
-
-AUTH_CALLBACK_PATH = "/auth/somfy/callback"
-AUTH_CALLBACK_NAME = "auth:somfy:callback"
+from homeassistant.helpers import config_entry_oauth2_flow
+from .const import DOMAIN
_LOGGER = logging.getLogger(__name__)
-@callback
-def register_flow_implementation(hass, client_id, client_secret):
- """Register a flow implementation.
+@config_entries.HANDLERS.register(DOMAIN)
+class SomfyFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
+ """Config flow to handle Somfy OAuth2 authentication."""
- client_id: Client id.
- client_secret: Client secret.
- """
- hass.data[DOMAIN][CLIENT_ID] = client_id
- hass.data[DOMAIN][CLIENT_SECRET] = client_secret
-
-
-@config_entries.HANDLERS.register("somfy")
-class SomfyFlowHandler(config_entries.ConfigFlow):
- """Handle a config flow."""
-
- VERSION = 1
+ DOMAIN = DOMAIN
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL
- def __init__(self):
- """Instantiate config flow."""
- self.code = None
-
- async def async_step_import(self, user_input=None):
- """Handle external yaml configuration."""
- if self.hass.config_entries.async_entries(DOMAIN):
- return self.async_abort(reason="already_setup")
- return await self.async_step_auth()
+ @property
+ def logger(self) -> logging.Logger:
+ """Return logger."""
+ return logging.getLogger(__name__)
async def async_step_user(self, user_input=None):
"""Handle a flow start."""
if self.hass.config_entries.async_entries(DOMAIN):
return self.async_abort(reason="already_setup")
- if DOMAIN not in self.hass.data:
- return self.async_abort(reason="missing_configuration")
-
- return await self.async_step_auth()
-
- async def async_step_auth(self, user_input=None):
- """Create an entry for auth."""
- # Flow has been triggered from Somfy website
- if user_input:
- return await self.async_step_code(user_input)
-
- try:
- with async_timeout.timeout(10):
- url, _ = await self._get_authorization_url()
- except asyncio.TimeoutError:
- return self.async_abort(reason="authorize_url_timeout")
-
- return self.async_external_step(step_id="auth", url=url)
-
- async def _get_authorization_url(self):
- """Get Somfy authorization url."""
- from pymfy.api.somfy_api import SomfyApi
-
- client_id = self.hass.data[DOMAIN][CLIENT_ID]
- client_secret = self.hass.data[DOMAIN][CLIENT_SECRET]
- redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}"
- api = SomfyApi(client_id, client_secret, redirect_uri)
-
- self.hass.http.register_view(SomfyAuthCallbackView())
- # Thanks to the state, we can forward the flow id to Somfy that will
- # add it in the callback.
- return await self.hass.async_add_executor_job(
- api.get_authorization_url, self.flow_id
- )
-
- async def async_step_code(self, code):
- """Received code for authentication."""
- self.code = code
- return self.async_external_step_done(next_step_id="creation")
-
- async def async_step_creation(self, user_input=None):
- """Create Somfy api and entries."""
- client_id = self.hass.data[DOMAIN][CLIENT_ID]
- client_secret = self.hass.data[DOMAIN][CLIENT_SECRET]
- code = self.code
- from pymfy.api.somfy_api import SomfyApi
-
- redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}"
- api = SomfyApi(client_id, client_secret, redirect_uri)
- token = await self.hass.async_add_executor_job(api.request_token, None, code)
- _LOGGER.info("Successfully authenticated Somfy")
- return self.async_create_entry(
- title="Somfy",
- data={
- "token": token,
- "refresh_args": {
- "client_id": client_id,
- "client_secret": client_secret,
- },
- },
- )
-
-
-class SomfyAuthCallbackView(HomeAssistantView):
- """Somfy Authorization Callback View."""
-
- requires_auth = False
- url = AUTH_CALLBACK_PATH
- name = AUTH_CALLBACK_NAME
-
- @staticmethod
- async def get(request):
- """Receive authorization code."""
- from aiohttp import web_response
-
- if "code" not in request.query or "state" not in request.query:
- return web_response.Response(
- text="Missing code or state parameter in " + request.url
- )
-
- hass = request.app["hass"]
- hass.async_create_task(
- hass.config_entries.flow.async_configure(
- flow_id=request.query["state"], user_input=request.query["code"]
- )
- )
-
- return web_response.Response(
- headers={"content-type": "text/html"},
- text="",
- )
+ return await super().async_step_user(user_input)
diff --git a/homeassistant/components/somfy/const.py b/homeassistant/components/somfy/const.py
index 99fafb71bff..8765e37e6d6 100644
--- a/homeassistant/components/somfy/const.py
+++ b/homeassistant/components/somfy/const.py
@@ -1,5 +1,3 @@
"""Define constants for the Somfy component."""
DOMAIN = "somfy"
-CLIENT_ID = "client_id"
-CLIENT_SECRET = "client_secret"
diff --git a/homeassistant/components/somfy/manifest.json b/homeassistant/components/somfy/manifest.json
index 83b50684fda..a34023f76ff 100644
--- a/homeassistant/components/somfy/manifest.json
+++ b/homeassistant/components/somfy/manifest.json
@@ -3,11 +3,7 @@
"name": "Somfy Open API",
"config_flow": true,
"documentation": "https://www.home-assistant.io/integrations/somfy",
- "dependencies": [],
- "codeowners": [
- "@tetienne"
- ],
- "requirements": [
- "pymfy==0.5.2"
- ]
-}
\ No newline at end of file
+ "dependencies": ["http"],
+ "codeowners": ["@tetienne"],
+ "requirements": ["pymfy==0.6.0"]
+}
diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py
index 8a40cff1bd5..f8c7c7a9da1 100644
--- a/homeassistant/config_entries.py
+++ b/homeassistant/config_entries.py
@@ -337,7 +337,7 @@ class ConfigEntry:
return False
if result:
# pylint: disable=protected-access
- hass.config_entries._async_schedule_save() # type: ignore
+ hass.config_entries._async_schedule_save()
return result
except Exception: # pylint: disable=broad-except
_LOGGER.exception(
diff --git a/homeassistant/core.py b/homeassistant/core.py
index 90d197906cb..ec11b14edaa 100644
--- a/homeassistant/core.py
+++ b/homeassistant/core.py
@@ -77,7 +77,8 @@ from homeassistant.util.unit_system import ( # NOQA
# Typing imports that create a circular dependency
# pylint: disable=using-constant-test
if TYPE_CHECKING:
- from homeassistant.config_entries import ConfigEntries # noqa
+ from homeassistant.config_entries import ConfigEntries
+ from homeassistant.components.http import HomeAssistantHTTP
# pylint: disable=invalid-name
T = TypeVar("T")
@@ -162,6 +163,9 @@ class CoreState(enum.Enum):
class HomeAssistant:
"""Root object of the Home Assistant home automation."""
+ http: "HomeAssistantHTTP" = None # type: ignore
+ config_entries: "ConfigEntries" = None # type: ignore
+
def __init__(self, loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None:
"""Initialize new Home Assistant object."""
self.loop: asyncio.events.AbstractEventLoop = (loop or asyncio.get_event_loop())
@@ -186,9 +190,6 @@ class HomeAssistant:
self.data: dict = {}
self.state = CoreState.not_running
self.exit_code = 0
- self.config_entries: Optional[
- ConfigEntries # pylint: disable=used-before-assignment
- ] = None
# If not None, use to signal end-of-loop
self._stopped: Optional[asyncio.Event] = None
diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py
index 0bc27498f76..c06c69d9213 100644
--- a/homeassistant/data_entry_flow.py
+++ b/homeassistant/data_entry_flow.py
@@ -168,7 +168,7 @@ class FlowHandler:
"""Handle the configuration flow of a component."""
# Set by flow manager
- flow_id: Optional[str] = None
+ flow_id: str = None # type: ignore
hass: Optional[HomeAssistant] = None
handler: Optional[Hashable] = None
cur_step: Optional[Dict[str, str]] = None
diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py
new file mode 100644
index 00000000000..043a28cac27
--- /dev/null
+++ b/homeassistant/helpers/config_entry_oauth2_flow.py
@@ -0,0 +1,420 @@
+"""Config Flow using OAuth2.
+
+This module exists of the following parts:
+ - OAuth2 config flow which supports multiple OAuth2 implementations
+ - OAuth2 implementation that works with local provided client ID/secret
+
+"""
+import asyncio
+from abc import ABCMeta, ABC, abstractmethod
+import logging
+from typing import Optional, Any, Dict, cast
+import time
+
+import async_timeout
+from aiohttp import web, client
+import jwt
+import voluptuous as vol
+from yarl import URL
+
+from homeassistant.auth.util import generate_secret
+from homeassistant.core import HomeAssistant, callback
+from homeassistant import config_entries
+from homeassistant.components.http import HomeAssistantView
+
+from .aiohttp_client import async_get_clientsession
+
+
+DATA_JWT_SECRET = "oauth2_jwt_secret"
+DATA_VIEW_REGISTERED = "oauth2_view_reg"
+DATA_IMPLEMENTATIONS = "oauth2_impl"
+AUTH_CALLBACK_PATH = "/auth/external/callback"
+
+
+class AbstractOAuth2Implementation(ABC):
+ """Base class to abstract OAuth2 authentication."""
+
+ @property
+ @abstractmethod
+ def name(self) -> str:
+ """Name of the implementation."""
+
+ @property
+ @abstractmethod
+ def domain(self) -> str:
+ """Domain that is providing the implementation."""
+
+ @abstractmethod
+ async def async_generate_authorize_url(self, flow_id: str) -> str:
+ """Generate a url for the user to authorize.
+
+ This step is called when a config flow is initialized. It should redirect the
+ user to the vendor website where they can authorize Home Assistant.
+
+ The implementation is responsible to get notified when the user is authorized
+ and pass this to the specified config flow. Do as little work as possible once
+ notified. You can do the work inside async_resolve_external_data. This will
+ give the best UX.
+
+ Pass external data in with:
+
+ ```python
+ await hass.config_entries.flow.async_configure(
+ flow_id=flow_id, user_input=external_data
+ )
+ ```
+ """
+
+ @abstractmethod
+ async def async_resolve_external_data(self, external_data: Any) -> dict:
+ """Resolve external data to tokens.
+
+ Turn the data that the implementation passed to the config flow as external
+ step data into tokens. These tokens will be stored as 'token' in the
+ config entry data.
+ """
+
+ async def async_refresh_token(self, token: dict) -> dict:
+ """Refresh a token and update expires info."""
+ new_token = await self._async_refresh_token(token)
+ new_token["expires_at"] = time.time() + new_token["expires_in"]
+ return new_token
+
+ @abstractmethod
+ async def _async_refresh_token(self, token: dict) -> dict:
+ """Refresh a token."""
+
+
+class LocalOAuth2Implementation(AbstractOAuth2Implementation):
+ """Local OAuth2 implementation."""
+
+ def __init__(
+ self,
+ hass: HomeAssistant,
+ domain: str,
+ client_id: str,
+ client_secret: str,
+ authorize_url: str,
+ token_url: str,
+ ):
+ """Initialize local auth implementation."""
+ self.hass = hass
+ self._domain = domain
+ self.client_id = client_id
+ self.client_secret = client_secret
+ self.authorize_url = authorize_url
+ self.token_url = token_url
+
+ @property
+ def name(self) -> str:
+ """Name of the implementation."""
+ return "Configuration.yaml"
+
+ @property
+ def domain(self) -> str:
+ """Domain providing the implementation."""
+ return self._domain
+
+ @property
+ def redirect_uri(self) -> str:
+ """Return the redirect uri."""
+ return f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" # type: ignore
+
+ async def async_generate_authorize_url(self, flow_id: str) -> str:
+ """Generate a url for the user to authorize."""
+ return str(
+ URL(self.authorize_url).with_query(
+ {
+ "response_type": "code",
+ "client_id": self.client_id,
+ "redirect_uri": self.redirect_uri,
+ "state": _encode_jwt(self.hass, {"flow_id": flow_id}),
+ }
+ )
+ )
+
+ async def async_resolve_external_data(self, external_data: Any) -> dict:
+ """Resolve the authorization code to tokens."""
+ return await self._token_request(
+ {
+ "grant_type": "authorization_code",
+ "code": external_data,
+ "redirect_uri": self.redirect_uri,
+ }
+ )
+
+ async def _async_refresh_token(self, token: dict) -> dict:
+ """Refresh tokens."""
+ new_token = await self._token_request(
+ {
+ "grant_type": "refresh_token",
+ "client_id": self.client_id,
+ "refresh_token": token["refresh_token"],
+ }
+ )
+ return {**token, **new_token}
+
+ async def _token_request(self, data: dict) -> dict:
+ """Make a token request."""
+ session = async_get_clientsession(self.hass)
+
+ data["client_id"] = self.client_id
+
+ if self.client_secret is not None:
+ data["client_secret"] = self.client_secret
+
+ resp = await session.post(self.token_url, data=data)
+ resp.raise_for_status()
+ return cast(dict, await resp.json())
+
+
+class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta):
+ """Handle a config flow."""
+
+ DOMAIN = ""
+
+ VERSION = 1
+ CONNECTION_CLASS = config_entries.CONN_CLASS_UNKNOWN
+
+ def __init__(self) -> None:
+ """Instantiate config flow."""
+ if self.DOMAIN == "":
+ raise TypeError(
+ f"Can't instantiate class {self.__class__.__name__} without DOMAIN being set"
+ )
+
+ self.external_data: Any = None
+ self.flow_impl: AbstractOAuth2Implementation = None # type: ignore
+
+ @property
+ @abstractmethod
+ def logger(self) -> logging.Logger:
+ """Return logger."""
+
+ @property
+ def extra_authorize_data(self) -> dict:
+ """Extra data that needs to be appended to the authorize url."""
+ return {}
+
+ async def async_step_pick_implementation(self, user_input: dict = None) -> dict:
+ """Handle a flow start."""
+ assert self.hass
+ implementations = await async_get_implementations(self.hass, self.DOMAIN)
+
+ if user_input is not None:
+ self.flow_impl = implementations[user_input["implementation"]]
+ return await self.async_step_auth()
+
+ if not implementations:
+ return self.async_abort(reason="missing_configuration")
+
+ if len(implementations) == 1:
+ # Pick first implementation as we have only one.
+ self.flow_impl = list(implementations.values())[0]
+ return await self.async_step_auth()
+
+ return self.async_show_form(
+ step_id="pick_implementation",
+ data_schema=vol.Schema(
+ {
+ vol.Required(
+ "implementation", default=list(implementations.keys())[0]
+ ): vol.In({key: impl.name for key, impl in implementations.items()})
+ }
+ ),
+ )
+
+ async def async_step_auth(self, user_input: dict = None) -> dict:
+ """Create an entry for auth."""
+ # Flow has been triggered by external data
+ if user_input:
+ self.external_data = user_input
+ return self.async_external_step_done(next_step_id="creation")
+
+ try:
+ with async_timeout.timeout(10):
+ url = await self.flow_impl.async_generate_authorize_url(self.flow_id)
+ except asyncio.TimeoutError:
+ return self.async_abort(reason="authorize_url_timeout")
+
+ url = str(URL(url).update_query(self.extra_authorize_data))
+
+ return self.async_external_step(step_id="auth", url=url)
+
+ async def async_step_creation(self, user_input: dict = None) -> dict:
+ """Create config entry from external data."""
+ token = await self.flow_impl.async_resolve_external_data(self.external_data)
+ token["expires_at"] = time.time() + token["expires_in"]
+
+ self.logger.info("Successfully authenticated")
+
+ return await self.async_oauth_create_entry(
+ {"auth_implementation": self.flow_impl.domain, "token": token}
+ )
+
+ async def async_oauth_create_entry(self, data: dict) -> dict:
+ """Create an entry for the flow.
+
+ Ok to override if you want to fetch extra info or even add another step.
+ """
+ return self.async_create_entry(title=self.flow_impl.name, data=data)
+
+ async_step_user = async_step_pick_implementation
+ async_step_ssdp = async_step_pick_implementation
+ async_step_zeroconf = async_step_pick_implementation
+ async_step_homekit = async_step_pick_implementation
+
+ @classmethod
+ def async_register_implementation(
+ cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation
+ ) -> None:
+ """Register a local implementation."""
+ async_register_implementation(hass, cls.DOMAIN, local_impl)
+
+
+@callback
+def async_register_implementation(
+ hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation
+) -> None:
+ """Register an OAuth2 flow implementation for an integration."""
+ if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get(
+ DATA_VIEW_REGISTERED, False
+ ):
+ hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore
+ hass.data[DATA_VIEW_REGISTERED] = True
+
+ implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {})
+ implementations.setdefault(domain, {})[implementation.domain] = implementation
+
+
+async def async_get_implementations(
+ hass: HomeAssistant, domain: str
+) -> Dict[str, AbstractOAuth2Implementation]:
+ """Return OAuth2 implementations for specified domain."""
+ return cast(
+ Dict[str, AbstractOAuth2Implementation],
+ hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
+ )
+
+
+async def async_get_config_entry_implementation(
+ hass: HomeAssistant, config_entry: config_entries.ConfigEntry
+) -> AbstractOAuth2Implementation:
+ """Return the implementation for this config entry."""
+ implementations = await async_get_implementations(hass, config_entry.domain)
+ implementation = implementations.get(config_entry.data["auth_implementation"])
+
+ if implementation is None:
+ raise ValueError("Implementation not available")
+
+ return implementation
+
+
+class OAuth2AuthorizeCallbackView(HomeAssistantView):
+ """OAuth2 Authorization Callback View."""
+
+ requires_auth = False
+ url = AUTH_CALLBACK_PATH
+ name = "auth:external:callback"
+
+ async def get(self, request: web.Request) -> web.Response:
+ """Receive authorization code."""
+ if "code" not in request.query or "state" not in request.query:
+ return web.Response(
+ text=f"Missing code or state parameter in {request.url}"
+ )
+
+ hass = request.app["hass"]
+
+ state = _decode_jwt(hass, request.query["state"])
+
+ if state is None:
+ return web.Response(text=f"Invalid state")
+
+ await hass.config_entries.flow.async_configure(
+ flow_id=state["flow_id"], user_input=request.query["code"]
+ )
+
+ return web.Response(
+ headers={"content-type": "text/html"},
+ text="",
+ )
+
+
+class OAuth2Session:
+ """Session to make requests authenticated with OAuth2."""
+
+ def __init__(
+ self,
+ hass: HomeAssistant,
+ config_entry: config_entries.ConfigEntry,
+ implementation: AbstractOAuth2Implementation,
+ ):
+ """Initialize an OAuth2 session."""
+ self.hass = hass
+ self.config_entry = config_entry
+ self.implementation = implementation
+
+ async def async_ensure_token_valid(self) -> None:
+ """Ensure that the current token is valid."""
+ token = self.config_entry.data["token"]
+
+ if token["expires_at"] > time.time():
+ return
+
+ new_token = await self.implementation.async_refresh_token(token)
+
+ self.hass.config_entries.async_update_entry( # type: ignore
+ self.config_entry, data={**self.config_entry.data, "token": new_token}
+ )
+
+ async def async_request(
+ self, method: str, url: str, **kwargs: Any
+ ) -> client.ClientResponse:
+ """Make a request."""
+ await self.async_ensure_token_valid()
+ return await async_oauth2_request(
+ self.hass, self.config_entry.data["token"], method, url, **kwargs
+ )
+
+
+async def async_oauth2_request(
+ hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any
+) -> client.ClientResponse:
+ """Make an OAuth2 authenticated request.
+
+ This method will not refresh tokens. Use OAuth2 session for that.
+ """
+ session = async_get_clientsession(hass)
+
+ return await session.request(
+ method,
+ url,
+ **kwargs,
+ headers={
+ **kwargs.get("headers", {}),
+ "authorization": f"Bearer {token['access_token']}",
+ },
+ )
+
+
+@callback
+def _encode_jwt(hass: HomeAssistant, data: dict) -> str:
+ """JWT encode data."""
+ secret = hass.data.get(DATA_JWT_SECRET)
+
+ if secret is None:
+ secret = hass.data[DATA_JWT_SECRET] = generate_secret()
+
+ return jwt.encode(data, secret, algorithm="HS256").decode()
+
+
+@callback
+def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]:
+ """JWT encode data."""
+ secret = cast(str, hass.data.get(DATA_JWT_SECRET))
+
+ try:
+ return jwt.decode(encoded, secret, algorithms=["HS256"])
+ except jwt.InvalidTokenError:
+ return None
diff --git a/requirements_all.txt b/requirements_all.txt
index 951c9800943..58a927c81ab 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -1304,7 +1304,7 @@ pymailgunner==1.4
pymediaroom==0.6.4
# homeassistant.components.somfy
-pymfy==0.5.2
+pymfy==0.6.0
# homeassistant.components.xiaomi_tv
pymitv==1.4.3
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index c9a0013212c..24122915fb5 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -447,7 +447,7 @@ pylitejet==0.1
pymailgunner==1.4
# homeassistant.components.somfy
-pymfy==0.5.2
+pymfy==0.6.0
# homeassistant.components.mochad
pymochad==0.2.0
diff --git a/tests/common.py b/tests/common.py
index 5532e6ccb5c..f40019c5d24 100644
--- a/tests/common.py
+++ b/tests/common.py
@@ -1015,14 +1015,23 @@ def mock_entity_platform(hass, platform_path, module):
hue.light.
"""
domain, platform_name = platform_path.split(".")
- integration_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
+ mock_platform(hass, f"{platform_name}.{domain}", module)
+
+
+def mock_platform(hass, platform_path, module=None):
+ """Mock a platform.
+
+ platform_path is in form hue.config_flow.
+ """
+ domain, platform_name = platform_path.split(".")
+ integration_cache = hass.data.setdefault(loader.DATA_INTEGRATIONS, {})
module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {})
- if platform_name not in integration_cache:
- mock_integration(hass, MockModule(platform_name))
+ if domain not in integration_cache:
+ mock_integration(hass, MockModule(domain))
_LOGGER.info("Adding mock integration platform: %s", platform_path)
- module_cache["{}.{}".format(platform_name, domain)] = module
+ module_cache[platform_path] = module or Mock()
def async_capture_events(hass, event_name):
diff --git a/tests/components/somfy/test_config_flow.py b/tests/components/somfy/test_config_flow.py
index cbc3784e3f5..d42e7b8e367 100644
--- a/tests/components/somfy/test_config_flow.py
+++ b/tests/components/somfy/test_config_flow.py
@@ -1,19 +1,35 @@
"""Tests for the Somfy config flow."""
import asyncio
-from unittest.mock import Mock, patch
+from unittest.mock import patch
-from pymfy.api.somfy_api import SomfyApi
+import pytest
-from homeassistant import data_entry_flow
+from homeassistant import data_entry_flow, setup, config_entries
from homeassistant.components.somfy import config_flow, DOMAIN
-from homeassistant.components.somfy.config_flow import register_flow_implementation
-from tests.common import MockConfigEntry, mock_coro
+from homeassistant.helpers import config_entry_oauth2_flow
+
+from tests.common import MockConfigEntry
CLIENT_SECRET_VALUE = "5678"
CLIENT_ID_VALUE = "1234"
-AUTH_URL = "http://somfy.com"
+
+@pytest.fixture()
+async def mock_impl(hass):
+ """Mock implementation."""
+ await setup.async_setup_component(hass, "http", {})
+
+ impl = config_entry_oauth2_flow.LocalOAuth2Implementation(
+ hass,
+ DOMAIN,
+ CLIENT_ID_VALUE,
+ CLIENT_SECRET_VALUE,
+ "https://accounts.somfy.com/oauth/oauth/v2/auth",
+ "https://accounts.somfy.com/oauth/oauth/v2/token",
+ )
+ config_flow.SomfyFlowHandler.async_register_implementation(hass, impl)
+ return impl
async def test_abort_if_no_configuration(hass):
@@ -30,47 +46,84 @@ async def test_abort_if_existing_entry(hass):
flow = config_flow.SomfyFlowHandler()
flow.hass = hass
MockConfigEntry(domain=DOMAIN).add_to_hass(hass)
- result = await flow.async_step_import()
- assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
- assert result["reason"] == "already_setup"
+
result = await flow.async_step_user()
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_setup"
-async def test_full_flow(hass):
- """Check classic use case."""
- hass.data[DOMAIN] = {}
- register_flow_implementation(hass, CLIENT_ID_VALUE, CLIENT_SECRET_VALUE)
- flow = config_flow.SomfyFlowHandler()
- flow.hass = hass
- hass.config.api = Mock(base_url="https://example.com")
- flow._get_authorization_url = Mock(return_value=mock_coro((AUTH_URL, "state")))
- result = await flow.async_step_import()
+async def test_full_flow(hass, aiohttp_client, aioclient_mock):
+ """Check full flow."""
+ assert await setup.async_setup_component(
+ hass,
+ "somfy",
+ {
+ "somfy": {
+ "client_id": CLIENT_ID_VALUE,
+ "client_secret": CLIENT_SECRET_VALUE,
+ },
+ "http": {"base_url": "https://example.com"},
+ },
+ )
+
+ result = await hass.config_entries.flow.async_init(
+ "somfy", context={"source": config_entries.SOURCE_USER}
+ )
+ state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
+
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
- assert result["url"] == AUTH_URL
- result = await flow.async_step_auth("my_super_code")
- assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP_DONE
- assert result["step_id"] == "creation"
- assert flow.code == "my_super_code"
- with patch.object(
- SomfyApi, "request_token", return_value={"access_token": "super_token"}
- ):
- result = await flow.async_step_creation()
- assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
- assert result["data"]["refresh_args"] == {
- "client_id": CLIENT_ID_VALUE,
- "client_secret": CLIENT_SECRET_VALUE,
+ assert result["url"] == (
+ "https://accounts.somfy.com/oauth/oauth/v2/auth"
+ f"?response_type=code&client_id={CLIENT_ID_VALUE}"
+ "&redirect_uri=https://example.com/auth/external/callback"
+ f"&state={state}"
+ )
+
+ client = await aiohttp_client(hass.http.app)
+ resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
+ assert resp.status == 200
+ assert resp.headers["content-type"] == "text/html; charset=utf-8"
+
+ aioclient_mock.post(
+ "https://accounts.somfy.com/oauth/oauth/v2/token",
+ json={
+ "refresh_token": "mock-refresh-token",
+ "access_token": "mock-access-token",
+ "type": "Bearer",
+ "expires_in": 60,
+ },
+ )
+
+ with patch("homeassistant.components.somfy.api.ConfigEntrySomfyApi"):
+ result = await hass.config_entries.flow.async_configure(result["flow_id"])
+
+ assert result["data"]["auth_implementation"] == "somfy"
+
+ result["data"]["token"].pop("expires_at")
+ assert result["data"]["token"] == {
+ "refresh_token": "mock-refresh-token",
+ "access_token": "mock-access-token",
+ "type": "Bearer",
+ "expires_in": 60,
}
- assert result["title"] == "Somfy"
- assert result["data"]["token"] == {"access_token": "super_token"}
+
+ assert "somfy" in hass.config.components
+ entry = hass.config_entries.async_entries("somfy")[0]
+ assert entry.state == config_entries.ENTRY_STATE_LOADED
+
+ assert await hass.config_entries.async_unload(entry.entry_id)
+ assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED
-async def test_abort_if_authorization_timeout(hass):
+async def test_abort_if_authorization_timeout(hass, mock_impl):
"""Check Somfy authorization timeout."""
flow = config_flow.SomfyFlowHandler()
flow.hass = hass
- flow._get_authorization_url = Mock(side_effect=asyncio.TimeoutError)
- result = await flow.async_step_auth()
+
+ with patch.object(
+ mock_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError
+ ):
+ result = await flow.async_step_user()
+
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "authorize_url_timeout"
diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py
new file mode 100644
index 00000000000..e47dd834bf7
--- /dev/null
+++ b/tests/helpers/test_config_entry_oauth2_flow.py
@@ -0,0 +1,266 @@
+"""Tests for the Somfy config flow."""
+import asyncio
+import logging
+from unittest.mock import patch
+import time
+
+import pytest
+
+from homeassistant import data_entry_flow, setup, config_entries
+from homeassistant.helpers import config_entry_oauth2_flow
+
+from tests.common import mock_platform, MockConfigEntry
+
+TEST_DOMAIN = "oauth2_test"
+CLIENT_SECRET = "5678"
+CLIENT_ID = "1234"
+REFRESH_TOKEN = "mock-refresh-token"
+ACCESS_TOKEN_1 = "mock-access-token-1"
+ACCESS_TOKEN_2 = "mock-access-token-2"
+AUTHORIZE_URL = "https://example.como/auth/authorize"
+TOKEN_URL = "https://example.como/auth/token"
+
+
+@pytest.fixture
+async def local_impl(hass):
+ """Local implementation."""
+ assert await setup.async_setup_component(hass, "http", {})
+ return config_entry_oauth2_flow.LocalOAuth2Implementation(
+ hass, TEST_DOMAIN, CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
+ )
+
+
+@pytest.fixture
+def flow_handler(hass):
+ """Return a registered config flow."""
+
+ mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
+
+ class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
+ """Test flow handler."""
+
+ DOMAIN = TEST_DOMAIN
+
+ @property
+ def logger(self) -> logging.Logger:
+ """Return logger."""
+ return logging.getLogger(__name__)
+
+ @property
+ def extra_authorize_data(self) -> dict:
+ """Extra data that needs to be appended to the authorize url."""
+ return {"scope": "read write"}
+
+ with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
+ yield TestFlowHandler
+
+
+class MockOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
+ """Mock implementation for testing."""
+
+ @property
+ def name(self) -> str:
+ """Name of the implementation."""
+ return "Mock"
+
+ @property
+ def domain(self) -> str:
+ """Domain that is providing the implementation."""
+ return "test"
+
+ async def async_generate_authorize_url(self, flow_id: str) -> str:
+ """Generate a url for the user to authorize."""
+ return "http://example.com/auth"
+
+ async def async_resolve_external_data(self, external_data) -> dict:
+ """Resolve external data to tokens."""
+ return external_data
+
+ async def _async_refresh_token(self, token: dict) -> dict:
+ """Refresh a token."""
+ raise NotImplementedError()
+
+
+def test_inherit_enforces_domain_set():
+ """Test we enforce setting DOMAIN."""
+
+ class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
+ """Test flow handler."""
+
+ @property
+ def logger(self) -> logging.Logger:
+ """Return logger."""
+ return logging.getLogger(__name__)
+
+ with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
+ with pytest.raises(TypeError):
+ TestFlowHandler()
+
+
+async def test_abort_if_no_implementation(hass, flow_handler):
+ """Check flow abort when no implementations."""
+ flow = flow_handler()
+ flow.hass = hass
+ result = await flow.async_step_user()
+ assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
+ assert result["reason"] == "missing_configuration"
+
+
+async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl):
+ """Check timeout generating authorization url."""
+ flow_handler.async_register_implementation(hass, local_impl)
+
+ flow = flow_handler()
+ flow.hass = hass
+
+ with patch.object(
+ local_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError
+ ):
+ result = await flow.async_step_user()
+
+ assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
+ assert result["reason"] == "authorize_url_timeout"
+
+
+async def test_full_flow(
+ hass, flow_handler, local_impl, aiohttp_client, aioclient_mock
+):
+ """Check full flow."""
+ hass.config.api.base_url = "https://example.com"
+ flow_handler.async_register_implementation(hass, local_impl)
+ config_entry_oauth2_flow.async_register_implementation(
+ hass, TEST_DOMAIN, MockOAuth2Implementation()
+ )
+
+ result = await hass.config_entries.flow.async_init(
+ TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
+ )
+
+ assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
+ assert result["step_id"] == "pick_implementation"
+
+ # Pick implementation
+ result = await hass.config_entries.flow.async_configure(
+ result["flow_id"], user_input={"implementation": TEST_DOMAIN}
+ )
+
+ state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]})
+
+ assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
+ assert result["url"] == (
+ f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}"
+ "&redirect_uri=https://example.com/auth/external/callback"
+ f"&state={state}&scope=read+write"
+ )
+
+ client = await aiohttp_client(hass.http.app)
+ resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
+ assert resp.status == 200
+ assert resp.headers["content-type"] == "text/html; charset=utf-8"
+
+ aioclient_mock.post(
+ TOKEN_URL,
+ json={
+ "refresh_token": REFRESH_TOKEN,
+ "access_token": ACCESS_TOKEN_1,
+ "type": "bearer",
+ "expires_in": 60,
+ },
+ )
+
+ result = await hass.config_entries.flow.async_configure(result["flow_id"])
+
+ assert result["data"]["auth_implementation"] == TEST_DOMAIN
+
+ result["data"]["token"].pop("expires_at")
+ assert result["data"]["token"] == {
+ "refresh_token": REFRESH_TOKEN,
+ "access_token": ACCESS_TOKEN_1,
+ "type": "bearer",
+ "expires_in": 60,
+ }
+
+ entry = hass.config_entries.async_entries(TEST_DOMAIN)[0]
+
+ assert (
+ await config_entry_oauth2_flow.async_get_config_entry_implementation(
+ hass, entry
+ )
+ is local_impl
+ )
+
+
+async def test_local_refresh_token(hass, local_impl, aioclient_mock):
+ """Test we can refresh token."""
+ aioclient_mock.post(
+ TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100}
+ )
+
+ new_tokens = await local_impl.async_refresh_token(
+ {
+ "refresh_token": REFRESH_TOKEN,
+ "access_token": ACCESS_TOKEN_1,
+ "type": "bearer",
+ "expires_in": 60,
+ }
+ )
+ new_tokens.pop("expires_at")
+
+ assert new_tokens == {
+ "refresh_token": REFRESH_TOKEN,
+ "access_token": ACCESS_TOKEN_2,
+ "type": "bearer",
+ "expires_in": 100,
+ }
+
+ assert len(aioclient_mock.mock_calls) == 1
+ assert aioclient_mock.mock_calls[0][2] == {
+ "client_id": CLIENT_ID,
+ "client_secret": CLIENT_SECRET,
+ "grant_type": "refresh_token",
+ "refresh_token": REFRESH_TOKEN,
+ }
+
+
+async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
+ """Test the OAuth2 session helper."""
+ flow_handler.async_register_implementation(hass, local_impl)
+
+ aioclient_mock.post(
+ TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100}
+ )
+
+ aioclient_mock.post("https://example.com", status=201)
+
+ config_entry = MockConfigEntry(
+ domain=TEST_DOMAIN,
+ data={
+ "auth_implementation": TEST_DOMAIN,
+ "token": {
+ "refresh_token": REFRESH_TOKEN,
+ "access_token": ACCESS_TOKEN_1,
+ "expires_in": 10,
+ "expires_at": 0, # Forces a refresh,
+ "token_type": "bearer",
+ "random_other_data": "should_stay",
+ },
+ },
+ )
+
+ now = time.time()
+ session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl)
+ resp = await session.async_request("post", "https://example.com")
+ assert resp.status == 201
+
+ # Refresh token, make request
+ assert len(aioclient_mock.mock_calls) == 2
+
+ assert (
+ aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}"
+ )
+
+ assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN
+ assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2
+ assert config_entry.data["token"]["expires_in"] == 100
+ assert config_entry.data["token"]["random_other_data"] == "should_stay"
+ assert round(config_entry.data["token"]["expires_at"] - now) == 100