From 08cc9fd3754c32cb5dc9f65a4c0ef0accb80700a Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 25 Oct 2019 16:04:24 -0700 Subject: [PATCH] Add cloud account linking support (#28210) * Add cloud account linking support * Update account_link.py --- homeassistant/bootstrap.py | 2 + homeassistant/components/cloud/__init__.py | 8 +- .../components/cloud/account_link.py | 132 +++++++++++++++ homeassistant/components/cloud/const.py | 1 + homeassistant/components/cloud/manifest.json | 2 +- .../components/somfy/.translations/en.json | 5 + homeassistant/components/somfy/strings.json | 29 ++-- .../helpers/config_entry_oauth2_flow.py | 41 ++++- homeassistant/package_constraints.txt | 2 +- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- tests/components/cloud/test_account_link.py | 160 ++++++++++++++++++ .../helpers/test_config_entry_oauth2_flow.py | 42 +++++ 13 files changed, 407 insertions(+), 21 deletions(-) create mode 100644 homeassistant/components/cloud/account_link.py create mode 100644 tests/components/cloud/test_account_link.py diff --git a/homeassistant/bootstrap.py b/homeassistant/bootstrap.py index 6118f4f2bd7..312c739cd72 100644 --- a/homeassistant/bootstrap.py +++ b/homeassistant/bootstrap.py @@ -33,6 +33,8 @@ STAGE_1_INTEGRATIONS = { "recorder", # To make sure we forward data to other instances "mqtt_eventstream", + # To provide account link implementations + "cloud", } diff --git a/homeassistant/components/cloud/__init__.py b/homeassistant/components/cloud/__init__.py index a2c79fdc0a7..2d5a2c8b448 100644 --- a/homeassistant/components/cloud/__init__.py +++ b/homeassistant/components/cloud/__init__.py @@ -20,7 +20,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter from homeassistant.loader import bind_hass from homeassistant.util.aiohttp import MockRequest -from . import http_api +from . import account_link, http_api from .client import CloudClient from .const import ( CONF_ACME_DIRECTORY_SERVER, @@ -38,6 +38,7 @@ from .const import ( CONF_REMOTE_API_URL, CONF_SUBSCRIPTION_INFO_URL, CONF_USER_POOL_ID, + CONF_ACCOUNT_LINK_URL, DOMAIN, MODE_DEV, MODE_PROD, @@ -101,6 +102,7 @@ CONFIG_SCHEMA = vol.Schema( vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA, vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(), vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(), + vol.Optional(CONF_ACCOUNT_LINK_URL): vol.Url(), } ) }, @@ -168,7 +170,6 @@ def is_cloudhook_request(request): async def async_setup(hass, config): """Initialize the Home Assistant cloud.""" - # Process configs if DOMAIN in config: kwargs = dict(config[DOMAIN]) @@ -248,4 +249,7 @@ async def async_setup(hass, config): cloud.iot.register_on_connect(_on_connect) await http_api.async_setup(hass) + + account_link.async_setup(hass) + return True diff --git a/homeassistant/components/cloud/account_link.py b/homeassistant/components/cloud/account_link.py new file mode 100644 index 00000000000..6fbfcc8723b --- /dev/null +++ b/homeassistant/components/cloud/account_link.py @@ -0,0 +1,132 @@ +"""Account linking via the cloud.""" +import asyncio +import logging +from typing import Any + +from hass_nabucasa import account_link + +from homeassistant.const import MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION +from homeassistant.core import HomeAssistant, callback +from homeassistant.helpers import event, config_entry_oauth2_flow + +from .const import DOMAIN + +DATA_SERVICES = "cloud_account_link_services" +CACHE_TIMEOUT = 3600 +PATCH_VERSION = int(PATCH_VERSION.split(".")[0]) +_LOGGER = logging.getLogger(__name__) + + +@callback +def async_setup(hass: HomeAssistant): + """Set up cloud account link.""" + config_entry_oauth2_flow.async_add_implementation_provider( + hass, DOMAIN, async_provide_implementation + ) + + +async def async_provide_implementation(hass: HomeAssistant, domain: str): + """Provide an implementation for a domain.""" + services = await _get_services(hass) + + for service in services: + if service["service"] == domain and _is_older(service["min_version"]): + return CloudOAuth2Implementation(hass, domain) + + return + + +@callback +def _is_older(version: str) -> bool: + """Test if a version is older than the current HA version.""" + version_parts = version.split(".") + + if len(version_parts) != 3: + return False + + try: + version_parts = [int(val) for val in version_parts] + except ValueError: + return False + + cur_version_parts = [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION] + + return version_parts <= cur_version_parts + + +async def _get_services(hass): + """Get the available services.""" + services = hass.data.get(DATA_SERVICES) + + if services is not None: + return services + + services = await account_link.async_fetch_available_services(hass.data[DOMAIN]) + + hass.data[DATA_SERVICES] = services + + @callback + def clear_services(_now): + """Clear services cache.""" + hass.data.pop(DATA_SERVICES, None) + + event.async_call_later(hass, CACHE_TIMEOUT, clear_services) + + return services + + +class CloudOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation): + """Cloud implementation of the OAuth2 flow.""" + + def __init__(self, hass: HomeAssistant, service: str): + """Initialize cloud OAuth2 implementation.""" + self.hass = hass + self.service = service + + @property + def name(self) -> str: + """Name of the implementation.""" + return "Home Assistant Cloud" + + @property + def domain(self) -> str: + """Domain that is providing the implementation.""" + return DOMAIN + + async def async_generate_authorize_url(self, flow_id: str) -> str: + """Generate a url for the user to authorize.""" + helper = account_link.AuthorizeAccountHelper( + self.hass.data[DOMAIN], self.service + ) + authorize_url = await helper.async_get_authorize_url() + + async def await_tokens(): + """Wait for tokens and pass them on when received.""" + try: + tokens = await helper.async_get_tokens() + + except asyncio.TimeoutError: + _LOGGER.info("Timeout fetching tokens for flow %s", flow_id) + except account_link.AccountLinkException as err: + _LOGGER.info( + "Failed to fetch tokens for flow %s: %s", flow_id, err.code + ) + else: + await self.hass.config_entries.flow.async_configure( + flow_id=flow_id, user_input=tokens + ) + + self.hass.async_create_task(await_tokens()) + + return authorize_url + + async def async_resolve_external_data(self, external_data: Any) -> dict: + """Resolve external data to tokens.""" + # We already passed in tokens + return external_data + + async def _async_refresh_token(self, token: dict) -> dict: + """Refresh a token.""" + return await account_link.async_fetch_access_token( + self.hass.data[DOMAIN], self.service, token["refresh_token"] + ) diff --git a/homeassistant/components/cloud/const.py b/homeassistant/components/cloud/const.py index 6495cba23b7..262f84a85e6 100644 --- a/homeassistant/components/cloud/const.py +++ b/homeassistant/components/cloud/const.py @@ -37,6 +37,7 @@ CONF_REMOTE_API_URL = "remote_api_url" CONF_ACME_DIRECTORY_SERVER = "acme_directory_server" CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url" CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url" +CONF_ACCOUNT_LINK_URL = "account_link_url" MODE_DEV = "development" MODE_PROD = "production" diff --git a/homeassistant/components/cloud/manifest.json b/homeassistant/components/cloud/manifest.json index c8fa6884563..9e9b77287ae 100644 --- a/homeassistant/components/cloud/manifest.json +++ b/homeassistant/components/cloud/manifest.json @@ -2,7 +2,7 @@ "domain": "cloud", "name": "Cloud", "documentation": "https://www.home-assistant.io/integrations/cloud", - "requirements": ["hass-nabucasa==0.22"], + "requirements": ["hass-nabucasa==0.23"], "dependencies": ["http", "webhook"], "codeowners": ["@home-assistant/cloud"] } diff --git a/homeassistant/components/somfy/.translations/en.json b/homeassistant/components/somfy/.translations/en.json index d4155915636..3b2f2e6beaf 100644 --- a/homeassistant/components/somfy/.translations/en.json +++ b/homeassistant/components/somfy/.translations/en.json @@ -8,6 +8,11 @@ "create_entry": { "default": "Successfully authenticated with Somfy." }, + "step": { + "pick_implementation": { + "title": "Pick Authentication Method" + } + }, "title": "Somfy" } } \ No newline at end of file diff --git a/homeassistant/components/somfy/strings.json b/homeassistant/components/somfy/strings.json index d4155915636..81308ba18af 100644 --- a/homeassistant/components/somfy/strings.json +++ b/homeassistant/components/somfy/strings.json @@ -1,13 +1,18 @@ { - "config": { - "abort": { - "already_setup": "You can only configure one Somfy account.", - "authorize_url_timeout": "Timeout generating authorize url.", - "missing_configuration": "The Somfy component is not configured. Please follow the documentation." - }, - "create_entry": { - "default": "Successfully authenticated with Somfy." - }, - "title": "Somfy" - } -} \ No newline at end of file + "config": { + "step": { + "pick_implementation": { + "title": "Pick Authentication Method" + } + }, + "abort": { + "already_setup": "You can only configure one Somfy account.", + "authorize_url_timeout": "Timeout generating authorize url.", + "missing_configuration": "The Somfy component is not configured. Please follow the documentation." + }, + "create_entry": { + "default": "Successfully authenticated with Somfy." + }, + "title": "Somfy" + } +} diff --git a/homeassistant/helpers/config_entry_oauth2_flow.py b/homeassistant/helpers/config_entry_oauth2_flow.py index 7fb954378ee..d3db8febcb2 100644 --- a/homeassistant/helpers/config_entry_oauth2_flow.py +++ b/homeassistant/helpers/config_entry_oauth2_flow.py @@ -8,7 +8,7 @@ This module exists of the following parts: import asyncio from abc import ABCMeta, ABC, abstractmethod import logging -from typing import Optional, Any, Dict, cast +from typing import Optional, Any, Dict, cast, Awaitable, Callable import time import async_timeout @@ -28,6 +28,7 @@ from .aiohttp_client import async_get_clientsession DATA_JWT_SECRET = "oauth2_jwt_secret" DATA_VIEW_REGISTERED = "oauth2_view_reg" DATA_IMPLEMENTATIONS = "oauth2_impl" +DATA_PROVIDERS = "oauth2_providers" AUTH_CALLBACK_PATH = "/auth/external/callback" @@ -291,11 +292,23 @@ async def async_get_implementations( hass: HomeAssistant, domain: str ) -> Dict[str, AbstractOAuth2Implementation]: """Return OAuth2 implementations for specified domain.""" - return cast( + registered = cast( Dict[str, AbstractOAuth2Implementation], hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}), ) + if DATA_PROVIDERS not in hass.data: + return registered + + registered = dict(registered) + + for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items(): + implementation = await get_impl(hass, domain) + if implementation is not None: + registered[provider_domain] = implementation + + return registered + async def async_get_config_entry_implementation( hass: HomeAssistant, config_entry: config_entries.ConfigEntry @@ -310,6 +323,23 @@ async def async_get_config_entry_implementation( return implementation +@callback +def async_add_implementation_provider( + hass: HomeAssistant, + provider_domain: str, + async_provide_implementation: Callable[ + [HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]] + ], +) -> None: + """Add an implementation provider. + + If no implementation found, return None. + """ + hass.data.setdefault(DATA_PROVIDERS, {})[ + provider_domain + ] = async_provide_implementation + + class OAuth2AuthorizeCallbackView(HomeAssistantView): """OAuth2 Authorization Callback View.""" @@ -355,9 +385,14 @@ class OAuth2Session: self.config_entry = config_entry self.implementation = implementation + @property + def token(self) -> dict: + """Return the current token.""" + return cast(dict, self.config_entry.data["token"]) + async def async_ensure_token_valid(self) -> None: """Ensure that the current token is valid.""" - token = self.config_entry.data["token"] + token = self.token if token["expires_at"] > time.time(): return diff --git a/homeassistant/package_constraints.txt b/homeassistant/package_constraints.txt index fbfa1dbf67b..87878b49615 100644 --- a/homeassistant/package_constraints.txt +++ b/homeassistant/package_constraints.txt @@ -10,7 +10,7 @@ certifi>=2019.9.11 contextvars==2.4;python_version<"3.7" cryptography==2.8 distro==1.4.0 -hass-nabucasa==0.22 +hass-nabucasa==0.23 home-assistant-frontend==20191025.0 importlib-metadata==0.23 jinja2>=2.10.1 diff --git a/requirements_all.txt b/requirements_all.txt index 8e8a72d0181..1039b53f4f3 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -616,7 +616,7 @@ habitipy==0.2.0 hangups==0.4.9 # homeassistant.components.cloud -hass-nabucasa==0.22 +hass-nabucasa==0.23 # homeassistant.components.mqtt hbmqtt==0.9.5 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 8f55cb1aa1b..db5c1a491cb 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -225,7 +225,7 @@ ha-ffmpeg==2.0 hangups==0.4.9 # homeassistant.components.cloud -hass-nabucasa==0.22 +hass-nabucasa==0.23 # homeassistant.components.mqtt hbmqtt==0.9.5 diff --git a/tests/components/cloud/test_account_link.py b/tests/components/cloud/test_account_link.py new file mode 100644 index 00000000000..60116895beb --- /dev/null +++ b/tests/components/cloud/test_account_link.py @@ -0,0 +1,160 @@ +"""Test account link services.""" +import asyncio +import logging +from time import time +from unittest.mock import Mock, patch + +import pytest + +from homeassistant import data_entry_flow, config_entries +from homeassistant.helpers import config_entry_oauth2_flow +from homeassistant.components.cloud import account_link +from homeassistant.util.dt import utcnow +from tests.common import mock_coro, async_fire_time_changed, mock_platform + + +TEST_DOMAIN = "oauth2_test" + + +@pytest.fixture +def flow_handler(hass): + """Return a registered config flow.""" + + mock_platform(hass, f"{TEST_DOMAIN}.config_flow") + + class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler): + """Test flow handler.""" + + DOMAIN = TEST_DOMAIN + + @property + def logger(self) -> logging.Logger: + """Return logger.""" + return logging.getLogger(__name__) + + with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}): + yield TestFlowHandler + + +async def test_setup_provide_implementation(hass): + """Test that we provide implementations.""" + account_link.async_setup(hass) + + with patch( + "homeassistant.components.cloud.account_link._get_services", + side_effect=lambda _: mock_coro( + [ + {"service": "test", "min_version": "0.1.0"}, + {"service": "too_new", "min_version": "100.0.0"}, + ] + ), + ): + assert ( + await config_entry_oauth2_flow.async_get_implementations( + hass, "non_existing" + ) + == {} + ) + assert ( + await config_entry_oauth2_flow.async_get_implementations(hass, "too_new") + == {} + ) + implementations = await config_entry_oauth2_flow.async_get_implementations( + hass, "test" + ) + + assert "cloud" in implementations + assert implementations["cloud"].domain == "cloud" + assert implementations["cloud"].service == "test" + assert implementations["cloud"].hass is hass + + +async def test_get_services_cached(hass): + """Test that we cache services.""" + hass.data["cloud"] = None + + services = 1 + + with patch.object(account_link, "CACHE_TIMEOUT", 0), patch( + "hass_nabucasa.account_link.async_fetch_available_services", + side_effect=lambda _: mock_coro(services), + ) as mock_fetch: + assert await account_link._get_services(hass) == 1 + + services = 2 + + assert len(mock_fetch.mock_calls) == 1 + assert await account_link._get_services(hass) == 1 + + services = 3 + hass.data.pop(account_link.DATA_SERVICES) + assert await account_link._get_services(hass) == 3 + + services = 4 + async_fire_time_changed(hass, utcnow()) + await hass.async_block_till_done() + + # Check cache purged + assert await account_link._get_services(hass) == 4 + + +async def test_implementation(hass, flow_handler): + """Test Cloud OAuth2 implementation.""" + hass.data["cloud"] = None + + impl = account_link.CloudOAuth2Implementation(hass, "test") + assert impl.name == "Home Assistant Cloud" + assert impl.domain == "cloud" + + flow_handler.async_register_implementation(hass, impl) + + flow_finished = asyncio.Future() + + helper = Mock( + async_get_authorize_url=Mock(return_value=mock_coro("http://example.com/auth")), + async_get_tokens=Mock(return_value=flow_finished), + ) + + with patch( + "hass_nabucasa.account_link.AuthorizeAccountHelper", return_value=helper + ): + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + + assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + assert result["url"] == "http://example.com/auth" + + flow_finished.set_result( + { + "refresh_token": "mock-refresh", + "access_token": "mock-access", + "expires_in": 10, + "token_type": "bearer", + } + ) + await hass.async_block_till_done() + + # Flow finished! + result = await hass.config_entries.flow.async_configure(result["flow_id"]) + + assert result["data"]["auth_implementation"] == "cloud" + + expires_at = result["data"]["token"].pop("expires_at") + assert round(expires_at - time()) == 10 + + assert result["data"]["token"] == { + "refresh_token": "mock-refresh", + "access_token": "mock-access", + "token_type": "bearer", + "expires_in": 10, + } + + entry = hass.config_entries.async_entries(TEST_DOMAIN)[0] + + assert ( + await config_entry_oauth2_flow.async_get_config_entry_implementation( + hass, entry + ) + is impl + ) diff --git a/tests/helpers/test_config_entry_oauth2_flow.py b/tests/helpers/test_config_entry_oauth2_flow.py index e47dd834bf7..773dfa09375 100644 --- a/tests/helpers/test_config_entry_oauth2_flow.py +++ b/tests/helpers/test_config_entry_oauth2_flow.py @@ -264,3 +264,45 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock): assert config_entry.data["token"]["expires_in"] == 100 assert config_entry.data["token"]["random_other_data"] == "should_stay" assert round(config_entry.data["token"]["expires_at"] - now) == 100 + + +async def test_implementation_provider(hass, local_impl): + """Test providing an implementation provider.""" + assert ( + await config_entry_oauth2_flow.async_get_implementations(hass, TEST_DOMAIN) + == {} + ) + + mock_domain_with_impl = "some_domain" + + config_entry_oauth2_flow.async_register_implementation( + hass, mock_domain_with_impl, local_impl + ) + + assert await config_entry_oauth2_flow.async_get_implementations( + hass, mock_domain_with_impl + ) == {TEST_DOMAIN: local_impl} + + provider_source = {} + + async def async_provide_implementation(hass, domain): + """Mock implementation provider.""" + return provider_source.get(domain) + + config_entry_oauth2_flow.async_add_implementation_provider( + hass, "cloud", async_provide_implementation + ) + + assert await config_entry_oauth2_flow.async_get_implementations( + hass, mock_domain_with_impl + ) == {TEST_DOMAIN: local_impl} + + provider_source[ + mock_domain_with_impl + ] = config_entry_oauth2_flow.LocalOAuth2Implementation( + hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL + ) + + assert await config_entry_oauth2_flow.async_get_implementations( + hass, mock_domain_with_impl + ) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}