diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index e399205ec70..6118f4f2bd7 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -260,7 +260,7 @@ def _get_domains(hass: core.HomeAssistant, config: Dict[str, Any]) -> Set[str]: domains = set(key.split(" ")[0] for key in config.keys() if key != core.DOMAIN) # Add config entry domains - domains.update(hass.config_entries.async_domains()) # type: ignore + domains.update(hass.config_entries.async_domains()) # Make sure the Hass.io component is loaded if "HASSIO" in os.environ: diff --git a/homeassistant/components/somfy/__init__.py b/homeassistant/components/somfy/__init__.py index 2c7c71d7a69..cd5960bf6b1 100644 --- a/homeassistant/components/somfy/__init__.py +++ b/homeassistant/components/somfy/__init__.py @@ -4,21 +4,21 @@ Support for Somfy hubs. For more details about this component, please refer to the documentation at https://home-assistant.io/integrations/somfy/ """ +import asyncio import logging from datetime import timedelta -from functools import partial import voluptuous as vol -import homeassistant.helpers.config_validation as cv -from homeassistant import config_entries +from homeassistant.helpers import config_validation as cv, config_entry_oauth2_flow from homeassistant.components.somfy import config_flow from homeassistant.config_entries import ConfigEntry -from homeassistant.const import CONF_TOKEN from homeassistant.helpers.entity import Entity from homeassistant.helpers.typing import HomeAssistantType from homeassistant.util import Throttle +from . import api + API = "api" DEVICES = "devices" @@ -52,19 +52,21 @@ SOMFY_COMPONENTS = ["cover"] async def async_setup(hass, config): """Set up the Somfy component.""" + hass.data[DOMAIN] = {} + if DOMAIN not in config: return True - hass.data[DOMAIN] = {} - - config_flow.register_flow_implementation( - hass, config[DOMAIN][CONF_CLIENT_ID], config[DOMAIN][CONF_CLIENT_SECRET] - ) - - hass.async_create_task( - hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_IMPORT} - ) + config_flow.SomfyFlowHandler.async_register_implementation( + hass, + config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, + DOMAIN, + config[DOMAIN][CONF_CLIENT_ID], + config[DOMAIN][CONF_CLIENT_SECRET], + "https://accounts.somfy.com/oauth/oauth/v2/auth", + "https://accounts.somfy.com/oauth/oauth/v2/token", + ), ) return True @@ -72,25 +74,18 @@ async def async_setup(hass, config): async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): """Set up Somfy from a config entry.""" - - def token_saver(token): - _LOGGER.debug("Saving updated token") - entry.data[CONF_TOKEN] = token - update_entry = partial( - hass.config_entries.async_update_entry, data={**entry.data} + # Backwards compat + if "auth_implementation" not in entry.data: + hass.config_entries.async_update_entry( + entry, data={**entry.data, "auth_implementation": DOMAIN} ) - hass.add_job(update_entry, entry) - # Force token update. - from pymfy.api.somfy_api import SomfyApi - - hass.data[DOMAIN][API] = SomfyApi( - entry.data["refresh_args"]["client_id"], - entry.data["refresh_args"]["client_secret"], - token=entry.data[CONF_TOKEN], - token_updater=token_saver, + implementation = await config_entry_oauth2_flow.async_get_config_entry_implementation( + hass, entry ) + hass.data[DOMAIN][API] = api.ConfigEntrySomfyApi(hass, entry, implementation) + await update_all_devices(hass) for component in SOMFY_COMPONENTS: @@ -104,16 +99,22 @@ async def async_setup_entry(hass: HomeAssistantType, entry: ConfigEntry): async def async_unload_entry(hass: HomeAssistantType, entry: ConfigEntry): """Unload a config entry.""" hass.data[DOMAIN].pop(API, None) + await asyncio.gather( + *[ + hass.config_entries.async_forward_entry_unload(entry, component) + for component in SOMFY_COMPONENTS + ] + ) return True class SomfyEntity(Entity): """Representation of a generic Somfy device.""" - def __init__(self, device, api): + def __init__(self, device, somfy_api): """Initialize the Somfy device.""" self.device = device - self.api = api + self.api = somfy_api @property def unique_id(self): diff --git a/homeassistant/components/somfy/api.py b/homeassistant/components/somfy/api.py new file mode 100644 index 00000000000..3e7bcf9deb4 --- /dev/null +++ b/homeassistant/components/somfy/api.py @@ -0,0 +1,55 @@ +"""API for Somfy bound to HASS OAuth.""" +from asyncio import run_coroutine_threadsafe +from functools import partial + +import requests +from pymfy.api import somfy_api + +from homeassistant import core, config_entries +from homeassistant.helpers import config_entry_oauth2_flow + + +class ConfigEntrySomfyApi(somfy_api.AbstractSomfyApi): + """Provide a Somfy API tied into an OAuth2 based config entry.""" + + def __init__( + self, + hass: core.HomeAssistant, + config_entry: config_entries.ConfigEntry, + implementation: config_entry_oauth2_flow.AbstractOAuth2Implementation, + ): + """Initialize the Config Entry Somfy API.""" + self.hass = hass + self.config_entry = config_entry + self.session = config_entry_oauth2_flow.OAuth2Session( + hass, config_entry, implementation + ) + + def get(self, path): + """Fetch a URL from the Somfy API.""" + return run_coroutine_threadsafe( + self._request("get", path), self.hass.loop + ).result() + + def post(self, path, *, json): + """Post data to the Somfy API.""" + return run_coroutine_threadsafe( + self._request("post", path, json=json), self.hass.loop + ).result() + + async def _request(self, method, path, **kwargs): + """Make a request.""" + await self.session.async_ensure_token_valid() + + return await self.hass.async_add_executor_job( + partial( + requests.request, + method, + f"{self.base_url}{path}", + **kwargs, + headers={ + **kwargs.get("headers", {}), + "authorization": f"Bearer {self.config_entry.data['token']['access_token']}", + }, + ) + ) diff --git a/homeassistant/components/somfy/config_flow.py b/homeassistant/components/somfy/config_flow.py index 9f3c58c8ffb..cb180d4e247 100644 --- a/homeassistant/components/somfy/config_flow.py +++ b/homeassistant/components/somfy/config_flow.py @@ -1,141 +1,28 @@ """Config flow for Somfy.""" -import asyncio import logging -import async_timeout - from homeassistant import config_entries -from homeassistant.components.http import HomeAssistantView -from homeassistant.core import callback -from .const import CLIENT_ID, CLIENT_SECRET, DOMAIN - -AUTH_CALLBACK_PATH = "/auth/somfy/callback" -AUTH_CALLBACK_NAME = "auth:somfy:callback" +from homeassistant.helpers import config_entry_oauth2_flow +from .const import DOMAIN _LOGGER = logging.getLogger(__name__) -@callback -def register_flow_implementation(hass, client_id, client_secret): - """Register a flow implementation. +@config_entries.HANDLERS.register(DOMAIN) +class SomfyFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): + """Config flow to handle Somfy OAuth2 authentication.""" - client_id: Client id. - client_secret: Client secret. - """ - hass.data[DOMAIN][CLIENT_ID] = client_id - hass.data[DOMAIN][CLIENT_SECRET] = client_secret - - -@config_entries.HANDLERS.register("somfy") -class SomfyFlowHandler(config_entries.ConfigFlow): - """Handle a config flow.""" - - VERSION = 1 + DOMAIN = DOMAIN CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL - def __init__(self): - """Instantiate config flow.""" - self.code = None - - async def async_step_import(self, user_input=None): - """Handle external yaml configuration.""" - if self.hass.config_entries.async_entries(DOMAIN): - return self.async_abort(reason="already_setup") - return await self.async_step_auth() + @property + def logger(self) -> logging.Logger: + """Return logger.""" + return logging.getLogger(__name__) async def async_step_user(self, user_input=None): """Handle a flow start.""" if self.hass.config_entries.async_entries(DOMAIN): return self.async_abort(reason="already_setup") - if DOMAIN not in self.hass.data: - return self.async_abort(reason="missing_configuration") - - return await self.async_step_auth() - - async def async_step_auth(self, user_input=None): - """Create an entry for auth.""" - # Flow has been triggered from Somfy website - if user_input: - return await self.async_step_code(user_input) - - try: - with async_timeout.timeout(10): - url, _ = await self._get_authorization_url() - except asyncio.TimeoutError: - return self.async_abort(reason="authorize_url_timeout") - - return self.async_external_step(step_id="auth", url=url) - - async def _get_authorization_url(self): - """Get Somfy authorization url.""" - from pymfy.api.somfy_api import SomfyApi - - client_id = self.hass.data[DOMAIN][CLIENT_ID] - client_secret = self.hass.data[DOMAIN][CLIENT_SECRET] - redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" - api = SomfyApi(client_id, client_secret, redirect_uri) - - self.hass.http.register_view(SomfyAuthCallbackView()) - # Thanks to the state, we can forward the flow id to Somfy that will - # add it in the callback. - return await self.hass.async_add_executor_job( - api.get_authorization_url, self.flow_id - ) - - async def async_step_code(self, code): - """Received code for authentication.""" - self.code = code - return self.async_external_step_done(next_step_id="creation") - - async def async_step_creation(self, user_input=None): - """Create Somfy api and entries.""" - client_id = self.hass.data[DOMAIN][CLIENT_ID] - client_secret = self.hass.data[DOMAIN][CLIENT_SECRET] - code = self.code - from pymfy.api.somfy_api import SomfyApi - - redirect_uri = f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" - api = SomfyApi(client_id, client_secret, redirect_uri) - token = await self.hass.async_add_executor_job(api.request_token, None, code) - _LOGGER.info("Successfully authenticated Somfy") - return self.async_create_entry( - title="Somfy", - data={ - "token": token, - "refresh_args": { - "client_id": client_id, - "client_secret": client_secret, - }, - }, - ) - - -class SomfyAuthCallbackView(HomeAssistantView): - """Somfy Authorization Callback View.""" - - requires_auth = False - url = AUTH_CALLBACK_PATH - name = AUTH_CALLBACK_NAME - - @staticmethod - async def get(request): - """Receive authorization code.""" - from aiohttp import web_response - - if "code" not in request.query or "state" not in request.query: - return web_response.Response( - text="Missing code or state parameter in " + request.url - ) - - hass = request.app["hass"] - hass.async_create_task( - hass.config_entries.flow.async_configure( - flow_id=request.query["state"], user_input=request.query["code"] - ) - ) - - return web_response.Response( - headers={"content-type": "text/html"}, - text="", - ) + return await super().async_step_user(user_input) diff --git a/homeassistant/components/somfy/const.py b/homeassistant/components/somfy/const.py index 99fafb71bff..8765e37e6d6 100644 --- a/homeassistant/components/somfy/const.py +++ b/homeassistant/components/somfy/const.py @@ -1,5 +1,3 @@ """Define constants for the Somfy component.""" DOMAIN = "somfy" -CLIENT_ID = "client_id" -CLIENT_SECRET = "client_secret" diff --git a/homeassistant/components/somfy/manifest.json b/homeassistant/components/somfy/manifest.json index 83b50684fda..a34023f76ff 100644 --- a/homeassistant/components/somfy/manifest.json +++ b/homeassistant/components/somfy/manifest.json @@ -3,11 +3,7 @@ "name": "Somfy Open API", "config_flow": true, "documentation": "https://www.home-assistant.io/integrations/somfy", - "dependencies": [], - "codeowners": [ - "@tetienne" - ], - "requirements": [ - "pymfy==0.5.2" - ] -} \ No newline at end of file + "dependencies": ["http"], + "codeowners": ["@tetienne"], + "requirements": ["pymfy==0.6.0"] +} diff --git a/homeassistant/config_entries.py b/homeassistant/config_entries.py index 8a40cff1bd5..f8c7c7a9da1 100644 --- a/homeassistant/config_entries.py +++ b/homeassistant/config_entries.py @@ -337,7 +337,7 @@ class ConfigEntry: return False if result: # pylint: disable=protected-access - hass.config_entries._async_schedule_save() # type: ignore + hass.config_entries._async_schedule_save() return result except Exception: # pylint: disable=broad-except _LOGGER.exception( diff --git a/homeassistant/core.py b/homeassistant/core.py index 90d197906cb..ec11b14edaa 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -77,7 +77,8 @@ from homeassistant.util.unit_system import ( # NOQA # Typing imports that create a circular dependency # pylint: disable=using-constant-test if TYPE_CHECKING: - from homeassistant.config_entries import ConfigEntries # noqa + from homeassistant.config_entries import ConfigEntries + from homeassistant.components.http import HomeAssistantHTTP # pylint: disable=invalid-name T = TypeVar("T") @@ -162,6 +163,9 @@ class CoreState(enum.Enum): class HomeAssistant: """Root object of the Home Assistant home automation.""" + http: "HomeAssistantHTTP" = None # type: ignore + config_entries: "ConfigEntries" = None # type: ignore + def __init__(self, loop: Optional[asyncio.events.AbstractEventLoop] = None) -> None: """Initialize new Home Assistant object.""" self.loop: asyncio.events.AbstractEventLoop = (loop or asyncio.get_event_loop()) @@ -186,9 +190,6 @@ class HomeAssistant: self.data: dict = {} self.state = CoreState.not_running self.exit_code = 0 - self.config_entries: Optional[ - ConfigEntries # pylint: disable=used-before-assignment - ] = None # If not None, use to signal end-of-loop self._stopped: Optional[asyncio.Event] = None diff --git a/homeassistant/data_entry_flow.py b/homeassistant/data_entry_flow.py index 0bc27498f76..c06c69d9213 100644 --- a/homeassistant/data_entry_flow.py +++ b/homeassistant/data_entry_flow.py @@ -168,7 +168,7 @@ class FlowHandler: """Handle the configuration flow of a component.""" # Set by flow manager - flow_id: Optional[str] = None + flow_id: str = None # type: ignore hass: Optional[HomeAssistant] = None handler: Optional[Hashable] = None cur_step: Optional[Dict[str, str]] = None diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py new file mode 100644 index 00000000000..043a28cac27 --- /dev/null +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -0,0 +1,420 @@ +"""Config Flow using OAuth2. + +This module exists of the following parts: + - OAuth2 config flow which supports multiple OAuth2 implementations + - OAuth2 implementation that works with local provided client ID/secret + +""" +import asyncio +from abc import ABCMeta, ABC, abstractmethod +import logging +from typing import Optional, Any, Dict, cast +import time + +import async_timeout +from aiohttp import web, client +import jwt +import voluptuous as vol +from yarl import URL + +from homeassistant.auth.util import generate_secret +from homeassistant.core import HomeAssistant, callback +from homeassistant import config_entries +from homeassistant.components.http import HomeAssistantView + +from .aiohttp_client import async_get_clientsession + + +DATA_JWT_SECRET = "oauth2_jwt_secret" +DATA_VIEW_REGISTERED = "oauth2_view_reg" +DATA_IMPLEMENTATIONS = "oauth2_impl" +AUTH_CALLBACK_PATH = "/auth/external/callback" + + +class AbstractOAuth2Implementation(ABC): + """Base class to abstract OAuth2 authentication.""" + + @property + @abstractmethod + def name(self) -> str: + """Name of the implementation.""" + + @property + @abstractmethod + def domain(self) -> str: + """Domain that is providing the implementation.""" + + @abstractmethod + async def async_generate_authorize_url(self, flow_id: str) -> str: + """Generate a url for the user to authorize. + + This step is called when a config flow is initialized. It should redirect the + user to the vendor website where they can authorize Home Assistant. + + The implementation is responsible to get notified when the user is authorized + and pass this to the specified config flow. Do as little work as possible once + notified. You can do the work inside async_resolve_external_data. This will + give the best UX. + + Pass external data in with: + + ```python + await hass.config_entries.flow.async_configure( + flow_id=flow_id, user_input=external_data + ) + ``` + """ + + @abstractmethod + async def async_resolve_external_data(self, external_data: Any) -> dict: + """Resolve external data to tokens. + + Turn the data that the implementation passed to the config flow as external + step data into tokens. These tokens will be stored as 'token' in the + config entry data. + """ + + async def async_refresh_token(self, token: dict) -> dict: + """Refresh a token and update expires info.""" + new_token = await self._async_refresh_token(token) + new_token["expires_at"] = time.time() + new_token["expires_in"] + return new_token + + @abstractmethod + async def _async_refresh_token(self, token: dict) -> dict: + """Refresh a token.""" + + +class LocalOAuth2Implementation(AbstractOAuth2Implementation): + """Local OAuth2 implementation.""" + + def __init__( + self, + hass: HomeAssistant, + domain: str, + client_id: str, + client_secret: str, + authorize_url: str, + token_url: str, + ): + """Initialize local auth implementation.""" + self.hass = hass + self._domain = domain + self.client_id = client_id + self.client_secret = client_secret + self.authorize_url = authorize_url + self.token_url = token_url + + @property + def name(self) -> str: + """Name of the implementation.""" + return "Configuration.yaml" + + @property + def domain(self) -> str: + """Domain providing the implementation.""" + return self._domain + + @property + def redirect_uri(self) -> str: + """Return the redirect uri.""" + return f"{self.hass.config.api.base_url}{AUTH_CALLBACK_PATH}" # type: ignore + + async def async_generate_authorize_url(self, flow_id: str) -> str: + """Generate a url for the user to authorize.""" + return str( + URL(self.authorize_url).with_query( + { + "response_type": "code", + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "state": _encode_jwt(self.hass, {"flow_id": flow_id}), + } + ) + ) + + async def async_resolve_external_data(self, external_data: Any) -> dict: + """Resolve the authorization code to tokens.""" + return await self._token_request( + { + "grant_type": "authorization_code", + "code": external_data, + "redirect_uri": self.redirect_uri, + } + ) + + async def _async_refresh_token(self, token: dict) -> dict: + """Refresh tokens.""" + new_token = await self._token_request( + { + "grant_type": "refresh_token", + "client_id": self.client_id, + "refresh_token": token["refresh_token"], + } + ) + return {**token, **new_token} + + async def _token_request(self, data: dict) -> dict: + """Make a token request.""" + session = async_get_clientsession(self.hass) + + data["client_id"] = self.client_id + + if self.client_secret is not None: + data["client_secret"] = self.client_secret + + resp = await session.post(self.token_url, data=data) + resp.raise_for_status() + return cast(dict, await resp.json()) + + +class AbstractOAuth2FlowHandler(config_entries.ConfigFlow, metaclass=ABCMeta): + """Handle a config flow.""" + + DOMAIN = "" + + VERSION = 1 + CONNECTION_CLASS = config_entries.CONN_CLASS_UNKNOWN + + def __init__(self) -> None: + """Instantiate config flow.""" + if self.DOMAIN == "": + raise TypeError( + f"Can't instantiate class {self.__class__.__name__} without DOMAIN being set" + ) + + self.external_data: Any = None + self.flow_impl: AbstractOAuth2Implementation = None # type: ignore + + @property + @abstractmethod + def logger(self) -> logging.Logger: + """Return logger.""" + + @property + def extra_authorize_data(self) -> dict: + """Extra data that needs to be appended to the authorize url.""" + return {} + + async def async_step_pick_implementation(self, user_input: dict = None) -> dict: + """Handle a flow start.""" + assert self.hass + implementations = await async_get_implementations(self.hass, self.DOMAIN) + + if user_input is not None: + self.flow_impl = implementations[user_input["implementation"]] + return await self.async_step_auth() + + if not implementations: + return self.async_abort(reason="missing_configuration") + + if len(implementations) == 1: + # Pick first implementation as we have only one. + self.flow_impl = list(implementations.values())[0] + return await self.async_step_auth() + + return self.async_show_form( + step_id="pick_implementation", + data_schema=vol.Schema( + { + vol.Required( + "implementation", default=list(implementations.keys())[0] + ): vol.In({key: impl.name for key, impl in implementations.items()}) + } + ), + ) + + async def async_step_auth(self, user_input: dict = None) -> dict: + """Create an entry for auth.""" + # Flow has been triggered by external data + if user_input: + self.external_data = user_input + return self.async_external_step_done(next_step_id="creation") + + try: + with async_timeout.timeout(10): + url = await self.flow_impl.async_generate_authorize_url(self.flow_id) + except asyncio.TimeoutError: + return self.async_abort(reason="authorize_url_timeout") + + url = str(URL(url).update_query(self.extra_authorize_data)) + + return self.async_external_step(step_id="auth", url=url) + + async def async_step_creation(self, user_input: dict = None) -> dict: + """Create config entry from external data.""" + token = await self.flow_impl.async_resolve_external_data(self.external_data) + token["expires_at"] = time.time() + token["expires_in"] + + self.logger.info("Successfully authenticated") + + return await self.async_oauth_create_entry( + {"auth_implementation": self.flow_impl.domain, "token": token} + ) + + async def async_oauth_create_entry(self, data: dict) -> dict: + """Create an entry for the flow. + + Ok to override if you want to fetch extra info or even add another step. + """ + return self.async_create_entry(title=self.flow_impl.name, data=data) + + async_step_user = async_step_pick_implementation + async_step_ssdp = async_step_pick_implementation + async_step_zeroconf = async_step_pick_implementation + async_step_homekit = async_step_pick_implementation + + @classmethod + def async_register_implementation( + cls, hass: HomeAssistant, local_impl: LocalOAuth2Implementation + ) -> None: + """Register a local implementation.""" + async_register_implementation(hass, cls.DOMAIN, local_impl) + + +@callback +def async_register_implementation( + hass: HomeAssistant, domain: str, implementation: AbstractOAuth2Implementation +) -> None: + """Register an OAuth2 flow implementation for an integration.""" + if isinstance(implementation, LocalOAuth2Implementation) and not hass.data.get( + DATA_VIEW_REGISTERED, False + ): + hass.http.register_view(OAuth2AuthorizeCallbackView()) # type: ignore + hass.data[DATA_VIEW_REGISTERED] = True + + implementations = hass.data.setdefault(DATA_IMPLEMENTATIONS, {}) + implementations.setdefault(domain, {})[implementation.domain] = implementation + + +async def async_get_implementations( + hass: HomeAssistant, domain: str +) -> Dict[str, AbstractOAuth2Implementation]: + """Return OAuth2 implementations for specified domain.""" + return cast( + Dict[str, AbstractOAuth2Implementation], + hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}), + ) + + +async def async_get_config_entry_implementation( + hass: HomeAssistant, config_entry: config_entries.ConfigEntry +) -> AbstractOAuth2Implementation: + """Return the implementation for this config entry.""" + implementations = await async_get_implementations(hass, config_entry.domain) + implementation = implementations.get(config_entry.data["auth_implementation"]) + + if implementation is None: + raise ValueError("Implementation not available") + + return implementation + + +class OAuth2AuthorizeCallbackView(HomeAssistantView): + """OAuth2 Authorization Callback View.""" + + requires_auth = False + url = AUTH_CALLBACK_PATH + name = "auth:external:callback" + + async def get(self, request: web.Request) -> web.Response: + """Receive authorization code.""" + if "code" not in request.query or "state" not in request.query: + return web.Response( + text=f"Missing code or state parameter in {request.url}" + ) + + hass = request.app["hass"] + + state = _decode_jwt(hass, request.query["state"]) + + if state is None: + return web.Response(text=f"Invalid state") + + await hass.config_entries.flow.async_configure( + flow_id=state["flow_id"], user_input=request.query["code"] + ) + + return web.Response( + headers={"content-type": "text/html"}, + text="", + ) + + +class OAuth2Session: + """Session to make requests authenticated with OAuth2.""" + + def __init__( + self, + hass: HomeAssistant, + config_entry: config_entries.ConfigEntry, + implementation: AbstractOAuth2Implementation, + ): + """Initialize an OAuth2 session.""" + self.hass = hass + self.config_entry = config_entry + self.implementation = implementation + + async def async_ensure_token_valid(self) -> None: + """Ensure that the current token is valid.""" + token = self.config_entry.data["token"] + + if token["expires_at"] > time.time(): + return + + new_token = await self.implementation.async_refresh_token(token) + + self.hass.config_entries.async_update_entry( # type: ignore + self.config_entry, data={**self.config_entry.data, "token": new_token} + ) + + async def async_request( + self, method: str, url: str, **kwargs: Any + ) -> client.ClientResponse: + """Make a request.""" + await self.async_ensure_token_valid() + return await async_oauth2_request( + self.hass, self.config_entry.data["token"], method, url, **kwargs + ) + + +async def async_oauth2_request( + hass: HomeAssistant, token: dict, method: str, url: str, **kwargs: Any +) -> client.ClientResponse: + """Make an OAuth2 authenticated request. + + This method will not refresh tokens. Use OAuth2 session for that. + """ + session = async_get_clientsession(hass) + + return await session.request( + method, + url, + **kwargs, + headers={ + **kwargs.get("headers", {}), + "authorization": f"Bearer {token['access_token']}", + }, + ) + + +@callback +def _encode_jwt(hass: HomeAssistant, data: dict) -> str: + """JWT encode data.""" + secret = hass.data.get(DATA_JWT_SECRET) + + if secret is None: + secret = hass.data[DATA_JWT_SECRET] = generate_secret() + + return jwt.encode(data, secret, algorithm="HS256").decode() + + +@callback +def _decode_jwt(hass: HomeAssistant, encoded: str) -> Optional[dict]: + """JWT encode data.""" + secret = cast(str, hass.data.get(DATA_JWT_SECRET)) + + try: + return jwt.decode(encoded, secret, algorithms=["HS256"]) + except jwt.InvalidTokenError: + return None diff --git a/requirements_all.txt b/requirements_all.txt index 951c9800943..58a927c81ab 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1304,7 +1304,7 @@ pymailgunner==1.4 pymediaroom==0.6.4 # homeassistant.components.somfy -pymfy==0.5.2 +pymfy==0.6.0 # homeassistant.components.xiaomi_tv pymitv==1.4.3 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index c9a0013212c..24122915fb5 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -447,7 +447,7 @@ pylitejet==0.1 pymailgunner==1.4 # homeassistant.components.somfy -pymfy==0.5.2 +pymfy==0.6.0 # homeassistant.components.mochad pymochad==0.2.0 diff --git a/tests/common.py b/tests/common.py index 5532e6ccb5c..f40019c5d24 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1015,14 +1015,23 @@ def mock_entity_platform(hass, platform_path, module): hue.light. """ domain, platform_name = platform_path.split(".") - integration_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) + mock_platform(hass, f"{platform_name}.{domain}", module) + + +def mock_platform(hass, platform_path, module=None): + """Mock a platform. + + platform_path is in form hue.config_flow. + """ + domain, platform_name = platform_path.split(".") + integration_cache = hass.data.setdefault(loader.DATA_INTEGRATIONS, {}) module_cache = hass.data.setdefault(loader.DATA_COMPONENTS, {}) - if platform_name not in integration_cache: - mock_integration(hass, MockModule(platform_name)) + if domain not in integration_cache: + mock_integration(hass, MockModule(domain)) _LOGGER.info("Adding mock integration platform: %s", platform_path) - module_cache["{}.{}".format(platform_name, domain)] = module + module_cache[platform_path] = module or Mock() def async_capture_events(hass, event_name): diff --git a/tests/components/somfy/test_config_flow.py b/tests/components/somfy/test_config_flow.py index cbc3784e3f5..d42e7b8e367 100644 --- a/tests/components/somfy/test_config_flow.py +++ b/tests/components/somfy/test_config_flow.py @@ -1,19 +1,35 @@ """Tests for the Somfy config flow.""" import asyncio -from unittest.mock import Mock, patch +from unittest.mock import patch -from pymfy.api.somfy_api import SomfyApi +import pytest -from homeassistant import data_entry_flow +from homeassistant import data_entry_flow, setup, config_entries from homeassistant.components.somfy import config_flow, DOMAIN -from homeassistant.components.somfy.config_flow import register_flow_implementation -from tests.common import MockConfigEntry, mock_coro +from homeassistant.helpers import config_entry_oauth2_flow + +from tests.common import MockConfigEntry CLIENT_SECRET_VALUE = "5678" CLIENT_ID_VALUE = "1234" -AUTH_URL = "http://somfy.com" + +@pytest.fixture() +async def mock_impl(hass): + """Mock implementation.""" + await setup.async_setup_component(hass, "http", {}) + + impl = config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, + DOMAIN, + CLIENT_ID_VALUE, + CLIENT_SECRET_VALUE, + "https://accounts.somfy.com/oauth/oauth/v2/auth", + "https://accounts.somfy.com/oauth/oauth/v2/token", + ) + config_flow.SomfyFlowHandler.async_register_implementation(hass, impl) + return impl async def test_abort_if_no_configuration(hass): @@ -30,47 +46,84 @@ async def test_abort_if_existing_entry(hass): flow = config_flow.SomfyFlowHandler() flow.hass = hass MockConfigEntry(domain=DOMAIN).add_to_hass(hass) - result = await flow.async_step_import() - assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT - assert result["reason"] == "already_setup" + result = await flow.async_step_user() assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["reason"] == "already_setup" -async def test_full_flow(hass): - """Check classic use case.""" - hass.data[DOMAIN] = {} - register_flow_implementation(hass, CLIENT_ID_VALUE, CLIENT_SECRET_VALUE) - flow = config_flow.SomfyFlowHandler() - flow.hass = hass - hass.config.api = Mock(base_url="https://example.com") - flow._get_authorization_url = Mock(return_value=mock_coro((AUTH_URL, "state"))) - result = await flow.async_step_import() +async def test_full_flow(hass, aiohttp_client, aioclient_mock): + """Check full flow.""" + assert await setup.async_setup_component( + hass, + "somfy", + { + "somfy": { + "client_id": CLIENT_ID_VALUE, + "client_secret": CLIENT_SECRET_VALUE, + }, + "http": {"base_url": "https://example.com"}, + }, + ) + + result = await hass.config_entries.flow.async_init( + "somfy", context={"source": config_entries.SOURCE_USER} + ) + state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP - assert result["url"] == AUTH_URL - result = await flow.async_step_auth("my_super_code") - assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP_DONE - assert result["step_id"] == "creation" - assert flow.code == "my_super_code" - with patch.object( - SomfyApi, "request_token", return_value={"access_token": "super_token"} - ): - result = await flow.async_step_creation() - assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["data"]["refresh_args"] == { - "client_id": CLIENT_ID_VALUE, - "client_secret": CLIENT_SECRET_VALUE, + assert result["url"] == ( + "https://accounts.somfy.com/oauth/oauth/v2/auth" + f"?response_type=code&client_id={CLIENT_ID_VALUE}" + "&redirect_uri=https://example.com/auth/external/callback" + f"&state={state}" + ) + + client = await aiohttp_client(hass.http.app) + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + + aioclient_mock.post( + "https://accounts.somfy.com/oauth/oauth/v2/token", + json={ + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, + }, + ) + + with patch("homeassistant.components.somfy.api.ConfigEntrySomfyApi"): + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["data"]["auth_implementation"] == "somfy" + + result["data"]["token"].pop("expires_at") + assert result["data"]["token"] == { + "refresh_token": "mock-refresh-token", + "access_token": "mock-access-token", + "type": "Bearer", + "expires_in": 60, } - assert result["title"] == "Somfy" - assert result["data"]["token"] == {"access_token": "super_token"} + + assert "somfy" in hass.config.components + entry = hass.config_entries.async_entries("somfy")[0] + assert entry.state == config_entries.ENTRY_STATE_LOADED + + assert await hass.config_entries.async_unload(entry.entry_id) + assert entry.state == config_entries.ENTRY_STATE_NOT_LOADED -async def test_abort_if_authorization_timeout(hass): +async def test_abort_if_authorization_timeout(hass, mock_impl): """Check Somfy authorization timeout.""" flow = config_flow.SomfyFlowHandler() flow.hass = hass - flow._get_authorization_url = Mock(side_effect=asyncio.TimeoutError) - result = await flow.async_step_auth() + + with patch.object( + mock_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError + ): + result = await flow.async_step_user() + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT assert result["reason"] == "authorize_url_timeout" diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py new file mode 100644 index 00000000000..e47dd834bf7 --- /dev/null +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -0,0 +1,266 @@ +"""Tests for the Somfy config flow.""" +import asyncio +import logging +from unittest.mock import patch +import time + +import pytest + +from homeassistant import data_entry_flow, setup, config_entries +from homeassistant.helpers import config_entry_oauth2_flow + +from tests.common import mock_platform, MockConfigEntry + +TEST_DOMAIN = "oauth2_test" +CLIENT_SECRET = "5678" +CLIENT_ID = "1234" +REFRESH_TOKEN = "mock-refresh-token" +ACCESS_TOKEN_1 = "mock-access-token-1" +ACCESS_TOKEN_2 = "mock-access-token-2" +AUTHORIZE_URL = "https://example.como/auth/authorize" +TOKEN_URL = "https://example.como/auth/token" + + +@pytest.fixture +async def local_impl(hass): + """Local implementation.""" + assert await setup.async_setup_component(hass, "http", {}) + return config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, TEST_DOMAIN, CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL + ) + + +@pytest.fixture +def flow_handler(hass): + """Return a registered config flow.""" + + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): + """Test flow handler.""" + + DOMAIN = TEST_DOMAIN + + @property + def logger(self) -> logging.Logger: + """Return logger.""" + return logging.getLogger(__name__) + + @property + def extra_authorize_data(self) -> dict: + """Extra data that needs to be appended to the authorize url.""" + return {"scope": "read write"} + + with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}): + yield TestFlowHandler + + +class MockOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation): + """Mock implementation for testing.""" + + @property + def name(self) -> str: + """Name of the implementation.""" + return "Mock" + + @property + def domain(self) -> str: + """Domain that is providing the implementation.""" + return "test" + + async def async_generate_authorize_url(self, flow_id: str) -> str: + """Generate a url for the user to authorize.""" + return "http://example.com/auth" + + async def async_resolve_external_data(self, external_data) -> dict: + """Resolve external data to tokens.""" + return external_data + + async def _async_refresh_token(self, token: dict) -> dict: + """Refresh a token.""" + raise NotImplementedError() + + +def test_inherit_enforces_domain_set(): + """Test we enforce setting DOMAIN.""" + + class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): + """Test flow handler.""" + + @property + def logger(self) -> logging.Logger: + """Return logger.""" + return logging.getLogger(__name__) + + with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}): + with pytest.raises(TypeError): + TestFlowHandler() + + +async def test_abort_if_no_implementation(hass, flow_handler): + """Check flow abort when no implementations.""" + flow = flow_handler() + flow.hass = hass + result = await flow.async_step_user() + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "missing_configuration" + + +async def test_abort_if_authorization_timeout(hass, flow_handler, local_impl): + """Check timeout generating authorization url.""" + flow_handler.async_register_implementation(hass, local_impl) + + flow = flow_handler() + flow.hass = hass + + with patch.object( + local_impl, "async_generate_authorize_url", side_effect=asyncio.TimeoutError + ): + result = await flow.async_step_user() + + assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT + assert result["reason"] == "authorize_url_timeout" + + +async def test_full_flow( + hass, flow_handler, local_impl, aiohttp_client, aioclient_mock +): + """Check full flow.""" + hass.config.api.base_url = "https://example.com" + flow_handler.async_register_implementation(hass, local_impl) + config_entry_oauth2_flow.async_register_implementation( + hass, TEST_DOMAIN, MockOAuth2Implementation() + ) + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_FORM + assert result["step_id"] == "pick_implementation" + + # Pick implementation + result = await hass.config_entries.flow.async_configure( + result["flow_id"], user_input={"implementation": TEST_DOMAIN} + ) + + state = config_entry_oauth2_flow._encode_jwt(hass, {"flow_id": result["flow_id"]}) + + assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + assert result["url"] == ( + f"{AUTHORIZE_URL}?response_type=code&client_id={CLIENT_ID}" + "&redirect_uri=https://example.com/auth/external/callback" + f"&state={state}&scope=read+write" + ) + + client = await aiohttp_client(hass.http.app) + resp = await client.get(f"/auth/external/callback?code=abcd&state={state}") + assert resp.status == 200 + assert resp.headers["content-type"] == "text/html; charset=utf-8" + + aioclient_mock.post( + TOKEN_URL, + json={ + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "type": "bearer", + "expires_in": 60, + }, + ) + + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["data"]["auth_implementation"] == TEST_DOMAIN + + result["data"]["token"].pop("expires_at") + assert result["data"]["token"] == { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "type": "bearer", + "expires_in": 60, + } + + entry = hass.config_entries.async_entries(TEST_DOMAIN)[0] + + assert ( + await config_entry_oauth2_flow.async_get_config_entry_implementation( + hass, entry + ) + is local_impl + ) + + +async def test_local_refresh_token(hass, local_impl, aioclient_mock): + """Test we can refresh token.""" + aioclient_mock.post( + TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100} + ) + + new_tokens = await local_impl.async_refresh_token( + { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "type": "bearer", + "expires_in": 60, + } + ) + new_tokens.pop("expires_at") + + assert new_tokens == { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_2, + "type": "bearer", + "expires_in": 100, + } + + assert len(aioclient_mock.mock_calls) == 1 + assert aioclient_mock.mock_calls[0][2] == { + "client_id": CLIENT_ID, + "client_secret": CLIENT_SECRET, + "grant_type": "refresh_token", + "refresh_token": REFRESH_TOKEN, + } + + +async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock): + """Test the OAuth2 session helper.""" + flow_handler.async_register_implementation(hass, local_impl) + + aioclient_mock.post( + TOKEN_URL, json={"access_token": ACCESS_TOKEN_2, "expires_in": 100} + ) + + aioclient_mock.post("https://example.com", status=201) + + config_entry = MockConfigEntry( + domain=TEST_DOMAIN, + data={ + "auth_implementation": TEST_DOMAIN, + "token": { + "refresh_token": REFRESH_TOKEN, + "access_token": ACCESS_TOKEN_1, + "expires_in": 10, + "expires_at": 0, # Forces a refresh, + "token_type": "bearer", + "random_other_data": "should_stay", + }, + }, + ) + + now = time.time() + session = config_entry_oauth2_flow.OAuth2Session(hass, config_entry, local_impl) + resp = await session.async_request("post", "https://example.com") + assert resp.status == 201 + + # Refresh token, make request + assert len(aioclient_mock.mock_calls) == 2 + + assert ( + aioclient_mock.mock_calls[1][3]["authorization"] == f"Bearer {ACCESS_TOKEN_2}" + ) + + assert config_entry.data["token"]["refresh_token"] == REFRESH_TOKEN + assert config_entry.data["token"]["access_token"] == ACCESS_TOKEN_2 + assert config_entry.data["token"]["expires_in"] == 100 + assert config_entry.data["token"]["random_other_data"] == "should_stay" + assert round(config_entry.data["token"]["expires_at"] - now) == 100