mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 06:07:17 +00:00
Add cloud account linking support (#28210)
* Add cloud account linking support * Update account_link.py
This commit is contained in:
parent
475b43500a
commit
08cc9fd375
@ -33,6 +33,8 @@ STAGE_1_INTEGRATIONS = {
|
|||||||
"recorder",
|
"recorder",
|
||||||
# To make sure we forward data to other instances
|
# To make sure we forward data to other instances
|
||||||
"mqtt_eventstream",
|
"mqtt_eventstream",
|
||||||
|
# To provide account link implementations
|
||||||
|
"cloud",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from homeassistant.helpers import config_validation as cv, entityfilter
|
|||||||
from homeassistant.loader import bind_hass
|
from homeassistant.loader import bind_hass
|
||||||
from homeassistant.util.aiohttp import MockRequest
|
from homeassistant.util.aiohttp import MockRequest
|
||||||
|
|
||||||
from . import http_api
|
from . import account_link, http_api
|
||||||
from .client import CloudClient
|
from .client import CloudClient
|
||||||
from .const import (
|
from .const import (
|
||||||
CONF_ACME_DIRECTORY_SERVER,
|
CONF_ACME_DIRECTORY_SERVER,
|
||||||
@ -38,6 +38,7 @@ from .const import (
|
|||||||
CONF_REMOTE_API_URL,
|
CONF_REMOTE_API_URL,
|
||||||
CONF_SUBSCRIPTION_INFO_URL,
|
CONF_SUBSCRIPTION_INFO_URL,
|
||||||
CONF_USER_POOL_ID,
|
CONF_USER_POOL_ID,
|
||||||
|
CONF_ACCOUNT_LINK_URL,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
MODE_DEV,
|
MODE_DEV,
|
||||||
MODE_PROD,
|
MODE_PROD,
|
||||||
@ -101,6 +102,7 @@ CONFIG_SCHEMA = vol.Schema(
|
|||||||
vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA,
|
vol.Optional(CONF_GOOGLE_ACTIONS): GACTIONS_SCHEMA,
|
||||||
vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(),
|
vol.Optional(CONF_ALEXA_ACCESS_TOKEN_URL): vol.Url(),
|
||||||
vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(),
|
vol.Optional(CONF_GOOGLE_ACTIONS_REPORT_STATE_URL): vol.Url(),
|
||||||
|
vol.Optional(CONF_ACCOUNT_LINK_URL): vol.Url(),
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
},
|
},
|
||||||
@ -168,7 +170,6 @@ def is_cloudhook_request(request):
|
|||||||
|
|
||||||
async def async_setup(hass, config):
|
async def async_setup(hass, config):
|
||||||
"""Initialize the Home Assistant cloud."""
|
"""Initialize the Home Assistant cloud."""
|
||||||
|
|
||||||
# Process configs
|
# Process configs
|
||||||
if DOMAIN in config:
|
if DOMAIN in config:
|
||||||
kwargs = dict(config[DOMAIN])
|
kwargs = dict(config[DOMAIN])
|
||||||
@ -248,4 +249,7 @@ async def async_setup(hass, config):
|
|||||||
cloud.iot.register_on_connect(_on_connect)
|
cloud.iot.register_on_connect(_on_connect)
|
||||||
|
|
||||||
await http_api.async_setup(hass)
|
await http_api.async_setup(hass)
|
||||||
|
|
||||||
|
account_link.async_setup(hass)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
132
homeassistant/components/cloud/account_link.py
Normal file
132
homeassistant/components/cloud/account_link.py
Normal file
@ -0,0 +1,132 @@
|
|||||||
|
"""Account linking via the cloud."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from hass_nabucasa import account_link
|
||||||
|
|
||||||
|
from homeassistant.const import MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION
|
||||||
|
from homeassistant.core import HomeAssistant, callback
|
||||||
|
from homeassistant.helpers import event, config_entry_oauth2_flow
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
|
||||||
|
DATA_SERVICES = "cloud_account_link_services"
|
||||||
|
CACHE_TIMEOUT = 3600
|
||||||
|
PATCH_VERSION = int(PATCH_VERSION.split(".")[0])
|
||||||
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_setup(hass: HomeAssistant):
|
||||||
|
"""Set up cloud account link."""
|
||||||
|
config_entry_oauth2_flow.async_add_implementation_provider(
|
||||||
|
hass, DOMAIN, async_provide_implementation
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def async_provide_implementation(hass: HomeAssistant, domain: str):
|
||||||
|
"""Provide an implementation for a domain."""
|
||||||
|
services = await _get_services(hass)
|
||||||
|
|
||||||
|
for service in services:
|
||||||
|
if service["service"] == domain and _is_older(service["min_version"]):
|
||||||
|
return CloudOAuth2Implementation(hass, domain)
|
||||||
|
|
||||||
|
return
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def _is_older(version: str) -> bool:
|
||||||
|
"""Test if a version is older than the current HA version."""
|
||||||
|
version_parts = version.split(".")
|
||||||
|
|
||||||
|
if len(version_parts) != 3:
|
||||||
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
version_parts = [int(val) for val in version_parts]
|
||||||
|
except ValueError:
|
||||||
|
return False
|
||||||
|
|
||||||
|
cur_version_parts = [MAJOR_VERSION, MINOR_VERSION, PATCH_VERSION]
|
||||||
|
|
||||||
|
return version_parts <= cur_version_parts
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_services(hass):
|
||||||
|
"""Get the available services."""
|
||||||
|
services = hass.data.get(DATA_SERVICES)
|
||||||
|
|
||||||
|
if services is not None:
|
||||||
|
return services
|
||||||
|
|
||||||
|
services = await account_link.async_fetch_available_services(hass.data[DOMAIN])
|
||||||
|
|
||||||
|
hass.data[DATA_SERVICES] = services
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def clear_services(_now):
|
||||||
|
"""Clear services cache."""
|
||||||
|
hass.data.pop(DATA_SERVICES, None)
|
||||||
|
|
||||||
|
event.async_call_later(hass, CACHE_TIMEOUT, clear_services)
|
||||||
|
|
||||||
|
return services
|
||||||
|
|
||||||
|
|
||||||
|
class CloudOAuth2Implementation(config_entry_oauth2_flow.AbstractOAuth2Implementation):
|
||||||
|
"""Cloud implementation of the OAuth2 flow."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, service: str):
|
||||||
|
"""Initialize cloud OAuth2 implementation."""
|
||||||
|
self.hass = hass
|
||||||
|
self.service = service
|
||||||
|
|
||||||
|
@property
|
||||||
|
def name(self) -> str:
|
||||||
|
"""Name of the implementation."""
|
||||||
|
return "Home Assistant Cloud"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def domain(self) -> str:
|
||||||
|
"""Domain that is providing the implementation."""
|
||||||
|
return DOMAIN
|
||||||
|
|
||||||
|
async def async_generate_authorize_url(self, flow_id: str) -> str:
|
||||||
|
"""Generate a url for the user to authorize."""
|
||||||
|
helper = account_link.AuthorizeAccountHelper(
|
||||||
|
self.hass.data[DOMAIN], self.service
|
||||||
|
)
|
||||||
|
authorize_url = await helper.async_get_authorize_url()
|
||||||
|
|
||||||
|
async def await_tokens():
|
||||||
|
"""Wait for tokens and pass them on when received."""
|
||||||
|
try:
|
||||||
|
tokens = await helper.async_get_tokens()
|
||||||
|
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_LOGGER.info("Timeout fetching tokens for flow %s", flow_id)
|
||||||
|
except account_link.AccountLinkException as err:
|
||||||
|
_LOGGER.info(
|
||||||
|
"Failed to fetch tokens for flow %s: %s", flow_id, err.code
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await self.hass.config_entries.flow.async_configure(
|
||||||
|
flow_id=flow_id, user_input=tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
self.hass.async_create_task(await_tokens())
|
||||||
|
|
||||||
|
return authorize_url
|
||||||
|
|
||||||
|
async def async_resolve_external_data(self, external_data: Any) -> dict:
|
||||||
|
"""Resolve external data to tokens."""
|
||||||
|
# We already passed in tokens
|
||||||
|
return external_data
|
||||||
|
|
||||||
|
async def _async_refresh_token(self, token: dict) -> dict:
|
||||||
|
"""Refresh a token."""
|
||||||
|
return await account_link.async_fetch_access_token(
|
||||||
|
self.hass.data[DOMAIN], self.service, token["refresh_token"]
|
||||||
|
)
|
@ -37,6 +37,7 @@ CONF_REMOTE_API_URL = "remote_api_url"
|
|||||||
CONF_ACME_DIRECTORY_SERVER = "acme_directory_server"
|
CONF_ACME_DIRECTORY_SERVER = "acme_directory_server"
|
||||||
CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url"
|
CONF_ALEXA_ACCESS_TOKEN_URL = "alexa_access_token_url"
|
||||||
CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url"
|
CONF_GOOGLE_ACTIONS_REPORT_STATE_URL = "google_actions_report_state_url"
|
||||||
|
CONF_ACCOUNT_LINK_URL = "account_link_url"
|
||||||
|
|
||||||
MODE_DEV = "development"
|
MODE_DEV = "development"
|
||||||
MODE_PROD = "production"
|
MODE_PROD = "production"
|
||||||
|
@ -2,7 +2,7 @@
|
|||||||
"domain": "cloud",
|
"domain": "cloud",
|
||||||
"name": "Cloud",
|
"name": "Cloud",
|
||||||
"documentation": "https://www.home-assistant.io/integrations/cloud",
|
"documentation": "https://www.home-assistant.io/integrations/cloud",
|
||||||
"requirements": ["hass-nabucasa==0.22"],
|
"requirements": ["hass-nabucasa==0.23"],
|
||||||
"dependencies": ["http", "webhook"],
|
"dependencies": ["http", "webhook"],
|
||||||
"codeowners": ["@home-assistant/cloud"]
|
"codeowners": ["@home-assistant/cloud"]
|
||||||
}
|
}
|
||||||
|
@ -8,6 +8,11 @@
|
|||||||
"create_entry": {
|
"create_entry": {
|
||||||
"default": "Successfully authenticated with Somfy."
|
"default": "Successfully authenticated with Somfy."
|
||||||
},
|
},
|
||||||
|
"step": {
|
||||||
|
"pick_implementation": {
|
||||||
|
"title": "Pick Authentication Method"
|
||||||
|
}
|
||||||
|
},
|
||||||
"title": "Somfy"
|
"title": "Somfy"
|
||||||
}
|
}
|
||||||
}
|
}
|
@ -1,13 +1,18 @@
|
|||||||
{
|
{
|
||||||
"config": {
|
"config": {
|
||||||
"abort": {
|
"step": {
|
||||||
"already_setup": "You can only configure one Somfy account.",
|
"pick_implementation": {
|
||||||
"authorize_url_timeout": "Timeout generating authorize url.",
|
"title": "Pick Authentication Method"
|
||||||
"missing_configuration": "The Somfy component is not configured. Please follow the documentation."
|
}
|
||||||
},
|
},
|
||||||
"create_entry": {
|
"abort": {
|
||||||
"default": "Successfully authenticated with Somfy."
|
"already_setup": "You can only configure one Somfy account.",
|
||||||
},
|
"authorize_url_timeout": "Timeout generating authorize url.",
|
||||||
"title": "Somfy"
|
"missing_configuration": "The Somfy component is not configured. Please follow the documentation."
|
||||||
}
|
},
|
||||||
|
"create_entry": {
|
||||||
|
"default": "Successfully authenticated with Somfy."
|
||||||
|
},
|
||||||
|
"title": "Somfy"
|
||||||
|
}
|
||||||
}
|
}
|
@ -8,7 +8,7 @@ This module exists of the following parts:
|
|||||||
import asyncio
|
import asyncio
|
||||||
from abc import ABCMeta, ABC, abstractmethod
|
from abc import ABCMeta, ABC, abstractmethod
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Any, Dict, cast
|
from typing import Optional, Any, Dict, cast, Awaitable, Callable
|
||||||
import time
|
import time
|
||||||
|
|
||||||
import async_timeout
|
import async_timeout
|
||||||
@ -28,6 +28,7 @@ from .aiohttp_client import async_get_clientsession
|
|||||||
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
DATA_JWT_SECRET = "oauth2_jwt_secret"
|
||||||
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
DATA_VIEW_REGISTERED = "oauth2_view_reg"
|
||||||
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
DATA_IMPLEMENTATIONS = "oauth2_impl"
|
||||||
|
DATA_PROVIDERS = "oauth2_providers"
|
||||||
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
AUTH_CALLBACK_PATH = "/auth/external/callback"
|
||||||
|
|
||||||
|
|
||||||
@ -291,11 +292,23 @@ async def async_get_implementations(
|
|||||||
hass: HomeAssistant, domain: str
|
hass: HomeAssistant, domain: str
|
||||||
) -> Dict[str, AbstractOAuth2Implementation]:
|
) -> Dict[str, AbstractOAuth2Implementation]:
|
||||||
"""Return OAuth2 implementations for specified domain."""
|
"""Return OAuth2 implementations for specified domain."""
|
||||||
return cast(
|
registered = cast(
|
||||||
Dict[str, AbstractOAuth2Implementation],
|
Dict[str, AbstractOAuth2Implementation],
|
||||||
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
|
hass.data.setdefault(DATA_IMPLEMENTATIONS, {}).get(domain, {}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if DATA_PROVIDERS not in hass.data:
|
||||||
|
return registered
|
||||||
|
|
||||||
|
registered = dict(registered)
|
||||||
|
|
||||||
|
for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items():
|
||||||
|
implementation = await get_impl(hass, domain)
|
||||||
|
if implementation is not None:
|
||||||
|
registered[provider_domain] = implementation
|
||||||
|
|
||||||
|
return registered
|
||||||
|
|
||||||
|
|
||||||
async def async_get_config_entry_implementation(
|
async def async_get_config_entry_implementation(
|
||||||
hass: HomeAssistant, config_entry: config_entries.ConfigEntry
|
hass: HomeAssistant, config_entry: config_entries.ConfigEntry
|
||||||
@ -310,6 +323,23 @@ async def async_get_config_entry_implementation(
|
|||||||
return implementation
|
return implementation
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_add_implementation_provider(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
provider_domain: str,
|
||||||
|
async_provide_implementation: Callable[
|
||||||
|
[HomeAssistant, str], Awaitable[Optional[AbstractOAuth2Implementation]]
|
||||||
|
],
|
||||||
|
) -> None:
|
||||||
|
"""Add an implementation provider.
|
||||||
|
|
||||||
|
If no implementation found, return None.
|
||||||
|
"""
|
||||||
|
hass.data.setdefault(DATA_PROVIDERS, {})[
|
||||||
|
provider_domain
|
||||||
|
] = async_provide_implementation
|
||||||
|
|
||||||
|
|
||||||
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
class OAuth2AuthorizeCallbackView(HomeAssistantView):
|
||||||
"""OAuth2 Authorization Callback View."""
|
"""OAuth2 Authorization Callback View."""
|
||||||
|
|
||||||
@ -355,9 +385,14 @@ class OAuth2Session:
|
|||||||
self.config_entry = config_entry
|
self.config_entry = config_entry
|
||||||
self.implementation = implementation
|
self.implementation = implementation
|
||||||
|
|
||||||
|
@property
|
||||||
|
def token(self) -> dict:
|
||||||
|
"""Return the current token."""
|
||||||
|
return cast(dict, self.config_entry.data["token"])
|
||||||
|
|
||||||
async def async_ensure_token_valid(self) -> None:
|
async def async_ensure_token_valid(self) -> None:
|
||||||
"""Ensure that the current token is valid."""
|
"""Ensure that the current token is valid."""
|
||||||
token = self.config_entry.data["token"]
|
token = self.token
|
||||||
|
|
||||||
if token["expires_at"] > time.time():
|
if token["expires_at"] > time.time():
|
||||||
return
|
return
|
||||||
|
@ -10,7 +10,7 @@ certifi>=2019.9.11
|
|||||||
contextvars==2.4;python_version<"3.7"
|
contextvars==2.4;python_version<"3.7"
|
||||||
cryptography==2.8
|
cryptography==2.8
|
||||||
distro==1.4.0
|
distro==1.4.0
|
||||||
hass-nabucasa==0.22
|
hass-nabucasa==0.23
|
||||||
home-assistant-frontend==20191025.0
|
home-assistant-frontend==20191025.0
|
||||||
importlib-metadata==0.23
|
importlib-metadata==0.23
|
||||||
jinja2>=2.10.1
|
jinja2>=2.10.1
|
||||||
|
@ -616,7 +616,7 @@ habitipy==0.2.0
|
|||||||
hangups==0.4.9
|
hangups==0.4.9
|
||||||
|
|
||||||
# homeassistant.components.cloud
|
# homeassistant.components.cloud
|
||||||
hass-nabucasa==0.22
|
hass-nabucasa==0.23
|
||||||
|
|
||||||
# homeassistant.components.mqtt
|
# homeassistant.components.mqtt
|
||||||
hbmqtt==0.9.5
|
hbmqtt==0.9.5
|
||||||
|
@ -225,7 +225,7 @@ ha-ffmpeg==2.0
|
|||||||
hangups==0.4.9
|
hangups==0.4.9
|
||||||
|
|
||||||
# homeassistant.components.cloud
|
# homeassistant.components.cloud
|
||||||
hass-nabucasa==0.22
|
hass-nabucasa==0.23
|
||||||
|
|
||||||
# homeassistant.components.mqtt
|
# homeassistant.components.mqtt
|
||||||
hbmqtt==0.9.5
|
hbmqtt==0.9.5
|
||||||
|
160
tests/components/cloud/test_account_link.py
Normal file
160
tests/components/cloud/test_account_link.py
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
"""Test account link services."""
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
from time import time
|
||||||
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from homeassistant import data_entry_flow, config_entries
|
||||||
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
from homeassistant.components.cloud import account_link
|
||||||
|
from homeassistant.util.dt import utcnow
|
||||||
|
from tests.common import mock_coro, async_fire_time_changed, mock_platform
|
||||||
|
|
||||||
|
|
||||||
|
TEST_DOMAIN = "oauth2_test"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def flow_handler(hass):
|
||||||
|
"""Return a registered config flow."""
|
||||||
|
|
||||||
|
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||||
|
|
||||||
|
class TestFlowHandler(config_entry_oauth2_flow.AbstractOAuth2FlowHandler):
|
||||||
|
"""Test flow handler."""
|
||||||
|
|
||||||
|
DOMAIN = TEST_DOMAIN
|
||||||
|
|
||||||
|
@property
|
||||||
|
def logger(self) -> logging.Logger:
|
||||||
|
"""Return logger."""
|
||||||
|
return logging.getLogger(__name__)
|
||||||
|
|
||||||
|
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: TestFlowHandler}):
|
||||||
|
yield TestFlowHandler
|
||||||
|
|
||||||
|
|
||||||
|
async def test_setup_provide_implementation(hass):
|
||||||
|
"""Test that we provide implementations."""
|
||||||
|
account_link.async_setup(hass)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.cloud.account_link._get_services",
|
||||||
|
side_effect=lambda _: mock_coro(
|
||||||
|
[
|
||||||
|
{"service": "test", "min_version": "0.1.0"},
|
||||||
|
{"service": "too_new", "min_version": "100.0.0"},
|
||||||
|
]
|
||||||
|
),
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
await config_entry_oauth2_flow.async_get_implementations(
|
||||||
|
hass, "non_existing"
|
||||||
|
)
|
||||||
|
== {}
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
await config_entry_oauth2_flow.async_get_implementations(hass, "too_new")
|
||||||
|
== {}
|
||||||
|
)
|
||||||
|
implementations = await config_entry_oauth2_flow.async_get_implementations(
|
||||||
|
hass, "test"
|
||||||
|
)
|
||||||
|
|
||||||
|
assert "cloud" in implementations
|
||||||
|
assert implementations["cloud"].domain == "cloud"
|
||||||
|
assert implementations["cloud"].service == "test"
|
||||||
|
assert implementations["cloud"].hass is hass
|
||||||
|
|
||||||
|
|
||||||
|
async def test_get_services_cached(hass):
|
||||||
|
"""Test that we cache services."""
|
||||||
|
hass.data["cloud"] = None
|
||||||
|
|
||||||
|
services = 1
|
||||||
|
|
||||||
|
with patch.object(account_link, "CACHE_TIMEOUT", 0), patch(
|
||||||
|
"hass_nabucasa.account_link.async_fetch_available_services",
|
||||||
|
side_effect=lambda _: mock_coro(services),
|
||||||
|
) as mock_fetch:
|
||||||
|
assert await account_link._get_services(hass) == 1
|
||||||
|
|
||||||
|
services = 2
|
||||||
|
|
||||||
|
assert len(mock_fetch.mock_calls) == 1
|
||||||
|
assert await account_link._get_services(hass) == 1
|
||||||
|
|
||||||
|
services = 3
|
||||||
|
hass.data.pop(account_link.DATA_SERVICES)
|
||||||
|
assert await account_link._get_services(hass) == 3
|
||||||
|
|
||||||
|
services = 4
|
||||||
|
async_fire_time_changed(hass, utcnow())
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Check cache purged
|
||||||
|
assert await account_link._get_services(hass) == 4
|
||||||
|
|
||||||
|
|
||||||
|
async def test_implementation(hass, flow_handler):
|
||||||
|
"""Test Cloud OAuth2 implementation."""
|
||||||
|
hass.data["cloud"] = None
|
||||||
|
|
||||||
|
impl = account_link.CloudOAuth2Implementation(hass, "test")
|
||||||
|
assert impl.name == "Home Assistant Cloud"
|
||||||
|
assert impl.domain == "cloud"
|
||||||
|
|
||||||
|
flow_handler.async_register_implementation(hass, impl)
|
||||||
|
|
||||||
|
flow_finished = asyncio.Future()
|
||||||
|
|
||||||
|
helper = Mock(
|
||||||
|
async_get_authorize_url=Mock(return_value=mock_coro("http://example.com/auth")),
|
||||||
|
async_get_tokens=Mock(return_value=flow_finished),
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"hass_nabucasa.account_link.AuthorizeAccountHelper", return_value=helper
|
||||||
|
):
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result["type"] == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
|
assert result["url"] == "http://example.com/auth"
|
||||||
|
|
||||||
|
flow_finished.set_result(
|
||||||
|
{
|
||||||
|
"refresh_token": "mock-refresh",
|
||||||
|
"access_token": "mock-access",
|
||||||
|
"expires_in": 10,
|
||||||
|
"token_type": "bearer",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
# Flow finished!
|
||||||
|
result = await hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
|
|
||||||
|
assert result["data"]["auth_implementation"] == "cloud"
|
||||||
|
|
||||||
|
expires_at = result["data"]["token"].pop("expires_at")
|
||||||
|
assert round(expires_at - time()) == 10
|
||||||
|
|
||||||
|
assert result["data"]["token"] == {
|
||||||
|
"refresh_token": "mock-refresh",
|
||||||
|
"access_token": "mock-access",
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": 10,
|
||||||
|
}
|
||||||
|
|
||||||
|
entry = hass.config_entries.async_entries(TEST_DOMAIN)[0]
|
||||||
|
|
||||||
|
assert (
|
||||||
|
await config_entry_oauth2_flow.async_get_config_entry_implementation(
|
||||||
|
hass, entry
|
||||||
|
)
|
||||||
|
is impl
|
||||||
|
)
|
@ -264,3 +264,45 @@ async def test_oauth_session(hass, flow_handler, local_impl, aioclient_mock):
|
|||||||
assert config_entry.data["token"]["expires_in"] == 100
|
assert config_entry.data["token"]["expires_in"] == 100
|
||||||
assert config_entry.data["token"]["random_other_data"] == "should_stay"
|
assert config_entry.data["token"]["random_other_data"] == "should_stay"
|
||||||
assert round(config_entry.data["token"]["expires_at"] - now) == 100
|
assert round(config_entry.data["token"]["expires_at"] - now) == 100
|
||||||
|
|
||||||
|
|
||||||
|
async def test_implementation_provider(hass, local_impl):
|
||||||
|
"""Test providing an implementation provider."""
|
||||||
|
assert (
|
||||||
|
await config_entry_oauth2_flow.async_get_implementations(hass, TEST_DOMAIN)
|
||||||
|
== {}
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_domain_with_impl = "some_domain"
|
||||||
|
|
||||||
|
config_entry_oauth2_flow.async_register_implementation(
|
||||||
|
hass, mock_domain_with_impl, local_impl
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await config_entry_oauth2_flow.async_get_implementations(
|
||||||
|
hass, mock_domain_with_impl
|
||||||
|
) == {TEST_DOMAIN: local_impl}
|
||||||
|
|
||||||
|
provider_source = {}
|
||||||
|
|
||||||
|
async def async_provide_implementation(hass, domain):
|
||||||
|
"""Mock implementation provider."""
|
||||||
|
return provider_source.get(domain)
|
||||||
|
|
||||||
|
config_entry_oauth2_flow.async_add_implementation_provider(
|
||||||
|
hass, "cloud", async_provide_implementation
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await config_entry_oauth2_flow.async_get_implementations(
|
||||||
|
hass, mock_domain_with_impl
|
||||||
|
) == {TEST_DOMAIN: local_impl}
|
||||||
|
|
||||||
|
provider_source[
|
||||||
|
mock_domain_with_impl
|
||||||
|
] = config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||||
|
hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
|
||||||
|
)
|
||||||
|
|
||||||
|
assert await config_entry_oauth2_flow.async_get_implementations(
|
||||||
|
hass, mock_domain_with_impl
|
||||||
|
) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}
|
||||||
|
Loading…
x
Reference in New Issue
Block a user