Add cloud account linking support (#28210)

* Add cloud account linking support

* Update account_link.py
This commit is contained in:
Paulus Schoutsen 2019-10-25 16:04:24 -07:00 committed by GitHub
parent 475b43500a
commit 08cc9fd375
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 407 additions and 21 deletions

View File

@ -33,6 +33,8 @@ STAGE_1_INTEGRATIONS = {
"recorder", "recorder",
# To make sure we forward data to other instances # To make sure we forward data to other instances
"mqtt_eventstream", "mqtt_eventstream",
# To provide account link implementations
"cloud",
} }

View File

@ -20,7 +20,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from homeassistant.util.aiohttp import MockRequest from homeassistant.util.aiohttp import MockRequest
from . import http_api from . import account_link, http_api
from .client import CloudClient from .client import CloudClient
from .const import ( from .const import (
CONF_ACME_DIRECTORY_SERVER, CONF_ACME_DIRECTORY_SERVER,
@ -38,6 +38,7 @@ from .const import (
CONF_REMOTE_API_URL, CONF_REMOTE_API_URL,
CONF_SUBSCRIPTION_INFO_URL, CONF_SUBSCRIPTION_INFO_URL,
CONF_USER_POOL_ID, CONF_USER_POOL_ID,
CONF_ACCOUNT_LINK_URL,
DOMAIN, DOMAIN,
MODE_DEV, MODE_DEV,
MODE_PROD, MODE_PROD,
@ -101,6 +102,7 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA, vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA,
vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(), vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(),
vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(), vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(),
vol.Optional(CONF_ACCOUNT_LINK_URL): vol.Url(),
} }
) )
}, },
@ -168,7 +170,6 @@ def is_cloudhook_request(request):
async def async_setup(hass, config): async def async_setup(hass, config):
"""Initialize the Home Assistant cloud.""" """Initialize the Home Assistant cloud."""
# Process configs # Process configs
if DOMAIN in config: if DOMAIN in config:
kwargs = dict(config[DOMAIN]) kwargs = dict(config[DOMAIN])
@ -248,4 +249,7 @@ async def async_setup(hass, config):
cloud.iot.register_on_connect(_on_connect) cloud.iot.register_on_connect(_on_connect)
await http_api.async_setup(hass) await http_api.async_setup(hass)
account_link.async_setup(hass)
return True return True

View File

@ -0,0 +1,132 @@
"""Account linking via the cloud."""
import asyncio
import logging
from typing import Any
from hass_nabucasa import account_link
from homeassistant.const import MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers import event, config_entry_oauth2_flow
from .const import DOMAIN
DATA_SERVICES = "cloud_account_link_services"
CACHE_TIMEOUT = 3600
PATCH_VERSION = int(PATCH_VERSION.split(".")[0])
_LOGGER = logging.getLogger(__name__)
@callback
def async_setup(hass: HomeAssistant):
"""Set up cloud account link."""
config_entry_oauth2_flow.async_add_implementation_provider(
hass, DOMAIN, async_provide_implementation
)
async def async_provide_implementation(hass: HomeAssistant, domain: str):
"""Provide an implementation for a domain."""
services = await _get_services(hass)
for service in services:
if service["service"] == domain and _is_older(service["min_version"]):
return CloudOAuth2Implementation(hass, domain)
return
@callback
def _is_older(version: str) -> bool:
"""Test if a version is older than the current HA version."""
version_parts = version.split(".")
if len(version_parts) != 3:
return False
try:
version_parts = [int(val) for val in version_parts]
except ValueError:
return False
cur_version_parts = [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]
return version_parts <= cur_version_parts
async def _get_services(hass):
"""Get the available services."""
services = hass.data.get(DATA_SERVICES)
if services is not None:
return services
services = await account_link.async_fetch_available_services(hass.data[DOMAIN])
hass.data[DATA_SERVICES] = services
@callback
def clear_services(_now):
"""Clear services cache."""
hass.data.pop(DATA_SERVICES, None)
event.async_call_later(hass, CACHE_TIMEOUT, clear_services)
return services
class CloudOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
"""Cloud implementation of the OAuth2 flow."""
def __init__(self, hass: HomeAssistant, service: str):
"""Initialize cloud OAuth2 implementation."""
self.hass = hass
self.service = service
@property
def name(self) -> str:
"""Name of the implementation."""
return "Home Assistant Cloud"
@property
def domain(self) -> str:
"""Domain that is providing the implementation."""
return DOMAIN
async def async_generate_authorize_url(self, flow_id: str) -> str:
"""Generate a url for the user to authorize."""
helper = account_link.AuthorizeAccountHelper(
self.hass.data[DOMAIN], self.service
)
authorize_url = await helper.async_get_authorize_url()
async def await_tokens():
"""Wait for tokens and pass them on when received."""
try:
tokens = await helper.async_get_tokens()
except asyncio.TimeoutError:
_LOGGER.info("Timeout fetching tokens for flow %s", flow_id)
except account_link.AccountLinkException as err:
_LOGGER.info(
"Failed to fetch tokens for flow %s: %s", flow_id, err.code
)
else:
await self.hass.config_entries.flow.async_configure(
flow_id=flow_id, user_input=tokens
)
self.hass.async_create_task(await_tokens())
return authorize_url
async def async_resolve_external_data(self, external_data: Any) -> dict:
"""Resolve external data to tokens."""
# We already passed in tokens
return external_data
async def _async_refresh_token(self, token: dict) -> dict:
"""Refresh a token."""
return await account_link.async_fetch_access_token(
self.hass.data[DOMAIN], self.service, token["refresh_token"]
)

View File

@ -37,6 +37,7 @@ CONF_REMOTE_API_URL = "remote_api_url"
CONF_ACME_DIRECTORY_SERVER = "acme_directory_server" CONF_ACME_DIRECTORY_SERVER = "acme_directory_server"
CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url" CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url"
CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url" CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url"
CONF_ACCOUNT_LINK_URL = "account_link_url"
MODE_DEV = "development" MODE_DEV = "development"
MODE_PROD = "production" MODE_PROD = "production"

View File

@ -2,7 +2,7 @@
"domain": "cloud", "domain": "cloud",
"name": "Cloud", "name": "Cloud",
"documentation": "https://www.home-assistant.io/integrations/cloud", "documentation": "https://www.home-assistant.io/integrations/cloud",
"requirements": ["hass-nabucasa==0.22"], "requirements": ["hass-nabucasa==0.23"],
"dependencies": ["http", "webhook"], "dependencies": ["http", "webhook"],
"codeowners": ["@home-assistant/cloud"] "codeowners": ["@home-assistant/cloud"]
} }

View File

@ -8,6 +8,11 @@
"create_entry": { "create_entry": {
"default": "Successfully authenticated with Somfy." "default": "Successfully authenticated with Somfy."
}, },
"step": {
"pick_implementation": {
"title": "Pick Authentication Method"
}
},
"title": "Somfy" "title": "Somfy"
} }
} }

View File

@ -1,13 +1,18 @@
{ {
"config": { "config": {
"abort": { "step": {
"already_setup": "You can only configure one Somfy account.", "pick_implementation": {
"authorize_url_timeout": "Timeout generating authorize url.", "title": "Pick Authentication Method"
"missing_configuration": "The Somfy component is not configured. Please follow the documentation." }
}, },
"create_entry": { "abort": {
"default": "Successfully authenticated with Somfy." "already_setup": "You can only configure one Somfy account.",
}, "authorize_url_timeout": "Timeout generating authorize url.",
"title": "Somfy" "missing_configuration": "The Somfy component is not configured. Please follow the documentation."
} },
"create_entry": {
"default": "Successfully authenticated with Somfy."
},
"title": "Somfy"
}
} }

View File

@ -8,7 +8,7 @@ This module exists of the following parts:
import asyncio import asyncio
from abc import ABCMeta, ABC, abstractmethod from abc import ABCMeta, ABC, abstractmethod
import logging import logging
from typing import Optional, Any, Dict, cast from typing import Optional, Any, Dict, cast, Awaitable, Callable
import time import time
import async_timeout import async_timeout
@ -28,6 +28,7 @@ from .aiohttp_client import async_get_clientsession
DATA_JWT_SECRET = "oauth2_jwt_secret" DATA_JWT_SECRET = "oauth2_jwt_secret"
DATA_VIEW_REGISTERED = "oauth2_view_reg" DATA_VIEW_REGISTERED = "oauth2_view_reg"
DATA_IMPLEMENTATIONS = "oauth2_impl" DATA_IMPLEMENTATIONS = "oauth2_impl"
DATA_PROVIDERS = "oauth2_providers"
AUTH_CALLBACK_PATH = "/auth/external/callback" AUTH_CALLBACK_PATH = "/auth/external/callback"
@ -291,11 +292,23 @@ async def async_get_implementations(
hass: HomeAssistant, domain: str hass: HomeAssistant, domain: str
) -> Dict[str, AbstractOAuth2Implementation]: ) -> Dict[str, AbstractOAuth2Implementation]:
"""Return OAuth2 implementations for specified domain.""" """Return OAuth2 implementations for specified domain."""
return cast( registered = cast(
Dict[str, AbstractOAuth2Implementation], Dict[str, AbstractOAuth2Implementation],
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}), hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
) )
if DATA_PROVIDERS not in hass.data:
return registered
registered = dict(registered)
for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items():
implementation = await get_impl(hass, domain)
if implementation is not None:
registered[provider_domain] = implementation
return registered
async def async_get_config_entry_implementation( async def async_get_config_entry_implementation(
hass: HomeAssistant, config_entry: config_entries.ConfigEntry hass: HomeAssistant, config_entry: config_entries.ConfigEntry
@ -310,6 +323,23 @@ async def async_get_config_entry_implementation(
return implementation return implementation
@callback
def async_add_implementation_provider(
hass: HomeAssistant,
provider_domain: str,
async_provide_implementation: Callable[
[HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]]
],
) -> None:
"""Add an implementation provider.
If no implementation found, return None.
"""
hass.data.setdefault(DATA_PROVIDERS, {})[
provider_domain
] = async_provide_implementation
class OAuth2AuthorizeCallbackView(HomeAssistantView): class OAuth2AuthorizeCallbackView(HomeAssistantView):
"""OAuth2 Authorization Callback View.""" """OAuth2 Authorization Callback View."""
@ -355,9 +385,14 @@ class OAuth2Session:
self.config_entry = config_entry self.config_entry = config_entry
self.implementation = implementation self.implementation = implementation
@property
def token(self) -> dict:
"""Return the current token."""
return cast(dict, self.config_entry.data["token"])
async def async_ensure_token_valid(self) -> None: async def async_ensure_token_valid(self) -> None:
"""Ensure that the current token is valid.""" """Ensure that the current token is valid."""
token = self.config_entry.data["token"] token = self.token
if token["expires_at"] > time.time(): if token["expires_at"] > time.time():
return return

View File

@ -10,7 +10,7 @@ certifi>=2019.9.11
contextvars==2.4;python_version<"3.7" contextvars==2.4;python_version<"3.7"
cryptography==2.8 cryptography==2.8
distro==1.4.0 distro==1.4.0
hass-nabucasa==0.22 hass-nabucasa==0.23
home-assistant-frontend==20191025.0 home-assistant-frontend==20191025.0
importlib-metadata==0.23 importlib-metadata==0.23
jinja2>=2.10.1 jinja2>=2.10.1

View File

@ -616,7 +616,7 @@ habitipy==0.2.0
hangups==0.4.9 hangups==0.4.9
# homeassistant.components.cloud # homeassistant.components.cloud
hass-nabucasa==0.22 hass-nabucasa==0.23
# homeassistant.components.mqtt # homeassistant.components.mqtt
hbmqtt==0.9.5 hbmqtt==0.9.5

View File

@ -225,7 +225,7 @@ ha-ffmpeg==2.0
hangups==0.4.9 hangups==0.4.9
# homeassistant.components.cloud # homeassistant.components.cloud
hass-nabucasa==0.22 hass-nabucasa==0.23
# homeassistant.components.mqtt # homeassistant.components.mqtt
hbmqtt==0.9.5 hbmqtt==0.9.5

View File

@ -0,0 +1,160 @@
"""Test account link services."""
import asyncio
import logging
from time import time
from unittest.mock import Mock, patch
import pytest
from homeassistant import data_entry_flow, config_entries
from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.components.cloud import account_link
from homeassistant.util.dt import utcnow
from tests.common import mock_coro, async_fire_time_changed, mock_platform
TEST_DOMAIN = "oauth2_test"
@pytest.fixture
def flow_handler(hass):
"""Return a registered config flow."""
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
"""Test flow handler."""
DOMAIN = TEST_DOMAIN
@property
def logger(self) -> logging.Logger:
"""Return logger."""
return logging.getLogger(__name__)
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
yield TestFlowHandler
async def test_setup_provide_implementation(hass):
"""Test that we provide implementations."""
account_link.async_setup(hass)
with patch(
"homeassistant.components.cloud.account_link._get_services",
side_effect=lambda _: mock_coro(
[
{"service": "test", "min_version": "0.1.0"},
{"service": "too_new", "min_version": "100.0.0"},
]
),
):
assert (
await config_entry_oauth2_flow.async_get_implementations(
hass, "non_existing"
)
== {}
)
assert (
await config_entry_oauth2_flow.async_get_implementations(hass, "too_new")
== {}
)
implementations = await config_entry_oauth2_flow.async_get_implementations(
hass, "test"
)
assert "cloud" in implementations
assert implementations["cloud"].domain == "cloud"
assert implementations["cloud"].service == "test"
assert implementations["cloud"].hass is hass
async def test_get_services_cached(hass):
"""Test that we cache services."""
hass.data["cloud"] = None
services = 1
with patch.object(account_link, "CACHE_TIMEOUT", 0), patch(
"hass_nabucasa.account_link.async_fetch_available_services",
side_effect=lambda _: mock_coro(services),
) as mock_fetch:
assert await account_link._get_services(hass) == 1
services = 2
assert len(mock_fetch.mock_calls) == 1
assert await account_link._get_services(hass) == 1
services = 3
hass.data.pop(account_link.DATA_SERVICES)
assert await account_link._get_services(hass) == 3
services = 4
async_fire_time_changed(hass, utcnow())
await hass.async_block_till_done()
# Check cache purged
assert await account_link._get_services(hass) == 4
async def test_implementation(hass, flow_handler):
"""Test Cloud OAuth2 implementation."""
hass.data["cloud"] = None
impl = account_link.CloudOAuth2Implementation(hass, "test")
assert impl.name == "Home Assistant Cloud"
assert impl.domain == "cloud"
flow_handler.async_register_implementation(hass, impl)
flow_finished = asyncio.Future()
helper = Mock(
async_get_authorize_url=Mock(return_value=mock_coro("http://example.com/auth")),
async_get_tokens=Mock(return_value=flow_finished),
)
with patch(
"hass_nabucasa.account_link.AuthorizeAccountHelper", return_value=helper
):
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
assert result["url"] == "http://example.com/auth"
flow_finished.set_result(
{
"refresh_token": "mock-refresh",
"access_token": "mock-access",
"expires_in": 10,
"token_type": "bearer",
}
)
await hass.async_block_till_done()
# Flow finished!
result = await hass.config_entries.flow.async_configure(result["flow_id"])
assert result["data"]["auth_implementation"] == "cloud"
expires_at = result["data"]["token"].pop("expires_at")
assert round(expires_at - time()) == 10
assert result["data"]["token"] == {
"refresh_token": "mock-refresh",
"access_token": "mock-access",
"token_type": "bearer",
"expires_in": 10,
}
entry = hass.config_entries.async_entries(TEST_DOMAIN)[0]
assert (
await config_entry_oauth2_flow.async_get_config_entry_implementation(
hass, entry
)
is impl
)

View File

@ -264,3 +264,45 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
assert config_entry.data["token"]["expires_in"] == 100 assert config_entry.data["token"]["expires_in"] == 100
assert config_entry.data["token"]["random_other_data"] == "should_stay" assert config_entry.data["token"]["random_other_data"] == "should_stay"
assert round(config_entry.data["token"]["expires_at"] - now) == 100 assert round(config_entry.data["token"]["expires_at"] - now) == 100
async def test_implementation_provider(hass, local_impl):
"""Test providing an implementation provider."""
assert (
await config_entry_oauth2_flow.async_get_implementations(hass, TEST_DOMAIN)
== {}
)
mock_domain_with_impl = "some_domain"
config_entry_oauth2_flow.async_register_implementation(
hass, mock_domain_with_impl, local_impl
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl}
provider_source = {}
async def async_provide_implementation(hass, domain):
"""Mock implementation provider."""
return provider_source.get(domain)
config_entry_oauth2_flow.async_add_implementation_provider(
hass, "cloud", async_provide_implementation
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl}
provider_source[
mock_domain_with_impl
] = config_entry_oauth2_flow.LocalOAuth2Implementation(
hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
)
assert await config_entry_oauth2_flow.async_get_implementations(
hass, mock_domain_with_impl
) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}