mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 00:37:53 +00:00
Add application credentials platform (#69148)
* Initial developer credentials scaffolding - Support websocket list/add/delete - Add developer credentials protocol from yaml config - Handle OAuth credential registration and de-registration - Tests for websocket and integration based registration * Fix pydoc text * Remove translations and update owners * Update homeassistant/components/developer_credentials/__init__.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Update homeassistant/components/developer_credentials/__init__.py Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> * Remove _async_get_developer_credential * Rename to application credentials platform * Fix race condition and add import support * Increase code coverage (92%) * Increase test coverage 93% * Increase test coverage (94%) * Increase test coverage (97%) * Increase test covearge (98%) * Increase test coverage (99%) * Increase test coverage (100%) * Remove http router frozen comment * Remove auth domain override on import * Remove debug statement * Don't import the same client id multiple times * Add auth dependency for local oauth implementation * Revert older oauth2 changes from merge * Update homeassistant/components/application_credentials/__init__.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Move config credential import to its own fixture * Override the mock_application_credentials_integration fixture instead per test * Update application credentials * Add dictionary typing * Use f-strings as per feedback * Add additional structure needed for an MVP application credential Add additional structure needed for an MVP, including a target component Xbox * Add websocket to list supported integrations for frontend selector * Application credentials config * Import xbox credentials * Remove unnecessary async calls * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Update script/hassfest/application_credentials.py Co-authored-by: Martin Hjelmare <marhje52@gmail.com> * Import credentials with a fixed auth domain Resolve an issue with compatibility of exisiting config entries when importing client credentials Co-authored-by: Paulus Schoutsen <paulus@home-assistant.io> Co-authored-by: Martin Hjelmare <marhje52@gmail.com>
This commit is contained in:
parent
ae8604d429
commit
00b5d30e24
@ -75,6 +75,8 @@ build.json @home-assistant/supervisor
|
||||
/tests/components/api/ @home-assistant/core
|
||||
/homeassistant/components/apple_tv/ @postlund
|
||||
/tests/components/apple_tv/ @postlund
|
||||
/homeassistant/components/application_credentials/ @home-assistant/core
|
||||
/tests/components/application_credentials/ @home-assistant/core
|
||||
/homeassistant/components/apprise/ @caronc
|
||||
/tests/components/apprise/ @caronc
|
||||
/homeassistant/components/aprs/ @PhilRW
|
||||
|
242
homeassistant/components/application_credentials/__init__.py
Normal file
242
homeassistant/components/application_credentials/__init__.py
Normal file
@ -0,0 +1,242 @@
|
||||
"""The Application Credentials integration.
|
||||
|
||||
This integration provides APIs for managing local OAuth credentials on behalf
|
||||
of other integrations. Integrations register an authorization server, and then
|
||||
the APIs are used to add one or more client credentials. Integrations may also
|
||||
provide credentials from yaml for backwards compatibility.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
import logging
|
||||
from typing import Any, Protocol
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN, CONF_ID
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.generated.application_credentials import APPLICATION_CREDENTIALS
|
||||
from homeassistant.helpers import collection, config_entry_oauth2_flow
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import IntegrationNotFound, async_get_integration
|
||||
from homeassistant.util import slugify
|
||||
|
||||
__all__ = ["ClientCredential", "AuthorizationServer", "async_import_client_credential"]
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
DOMAIN = "application_credentials"
|
||||
|
||||
STORAGE_KEY = DOMAIN
|
||||
STORAGE_VERSION = 1
|
||||
DATA_STORAGE = "storage"
|
||||
CONF_AUTH_DOMAIN = "auth_domain"
|
||||
|
||||
CREATE_FIELDS = {
|
||||
vol.Required(CONF_DOMAIN): cv.string,
|
||||
vol.Required(CONF_CLIENT_ID): cv.string,
|
||||
vol.Required(CONF_CLIENT_SECRET): cv.string,
|
||||
vol.Optional(CONF_AUTH_DOMAIN): cv.string,
|
||||
}
|
||||
UPDATE_FIELDS: dict = {} # Not supported
|
||||
|
||||
|
||||
@dataclass
|
||||
class ClientCredential:
|
||||
"""Represent an OAuth client credential."""
|
||||
|
||||
client_id: str
|
||||
client_secret: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthorizationServer:
|
||||
"""Represent an OAuth2 Authorization Server."""
|
||||
|
||||
authorize_url: str
|
||||
token_url: str
|
||||
|
||||
|
||||
class ApplicationCredentialsStorageCollection(collection.StorageCollection):
|
||||
"""Application credential collection stored in storage."""
|
||||
|
||||
CREATE_SCHEMA = vol.Schema(CREATE_FIELDS)
|
||||
|
||||
async def _process_create_data(self, data: dict[str, str]) -> dict[str, str]:
|
||||
"""Validate the config is valid."""
|
||||
result = self.CREATE_SCHEMA(data)
|
||||
domain = result[CONF_DOMAIN]
|
||||
if not await _get_platform(self.hass, domain):
|
||||
raise ValueError(f"No application_credentials platform for {domain}")
|
||||
return result
|
||||
|
||||
@callback
|
||||
def _get_suggested_id(self, info: dict[str, str]) -> str:
|
||||
"""Suggest an ID based on the config."""
|
||||
return f"{info[CONF_DOMAIN]}.{info[CONF_CLIENT_ID]}"
|
||||
|
||||
async def _update_data(
|
||||
self, data: dict[str, str], update_data: dict[str, str]
|
||||
) -> dict[str, str]:
|
||||
"""Return a new updated data object."""
|
||||
raise ValueError("Updates not supported")
|
||||
|
||||
async def async_delete_item(self, item_id: str) -> None:
|
||||
"""Delete item, verifying credential is not in use."""
|
||||
if item_id not in self.data:
|
||||
raise collection.ItemNotFound(item_id)
|
||||
|
||||
# Cannot delete a credential currently in use by a ConfigEntry
|
||||
current = self.data[item_id]
|
||||
entries = self.hass.config_entries.async_entries(current[CONF_DOMAIN])
|
||||
for entry in entries:
|
||||
if entry.data.get("auth_implementation") == item_id:
|
||||
raise ValueError("Cannot delete credential in use by an integration")
|
||||
|
||||
await super().async_delete_item(item_id)
|
||||
|
||||
async def async_import_item(self, info: dict[str, str]) -> None:
|
||||
"""Import an yaml credential if it does not already exist."""
|
||||
suggested_id = self._get_suggested_id(info)
|
||||
if self.id_manager.has_id(slugify(suggested_id)):
|
||||
return
|
||||
await self.async_create_item(info)
|
||||
|
||||
def async_client_credentials(self, domain: str) -> dict[str, ClientCredential]:
|
||||
"""Return ClientCredentials in storage for the specified domain."""
|
||||
credentials = {}
|
||||
for item in self.async_items():
|
||||
if item[CONF_DOMAIN] != domain:
|
||||
continue
|
||||
auth_domain = (
|
||||
item[CONF_AUTH_DOMAIN] if CONF_AUTH_DOMAIN in item else item[CONF_ID]
|
||||
)
|
||||
credentials[auth_domain] = ClientCredential(
|
||||
item[CONF_CLIENT_ID], item[CONF_CLIENT_SECRET]
|
||||
)
|
||||
return credentials
|
||||
|
||||
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up Application Credentials."""
|
||||
hass.data[DOMAIN] = {}
|
||||
|
||||
id_manager = collection.IDManager()
|
||||
storage_collection = ApplicationCredentialsStorageCollection(
|
||||
Store(hass, STORAGE_VERSION, STORAGE_KEY),
|
||||
logging.getLogger(f"{__name__}.storage_collection"),
|
||||
id_manager,
|
||||
)
|
||||
await storage_collection.async_load()
|
||||
hass.data[DOMAIN][DATA_STORAGE] = storage_collection
|
||||
|
||||
collection.StorageCollectionWebsocket(
|
||||
storage_collection, DOMAIN, DOMAIN, CREATE_FIELDS, UPDATE_FIELDS
|
||||
).async_setup(hass)
|
||||
|
||||
websocket_api.async_register_command(hass, handle_integration_list)
|
||||
|
||||
config_entry_oauth2_flow.async_add_implementation_provider(
|
||||
hass, DOMAIN, _async_provide_implementation
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
async def async_import_client_credential(
|
||||
hass: HomeAssistant, domain: str, credential: ClientCredential
|
||||
) -> None:
|
||||
"""Import an existing credential from configuration.yaml."""
|
||||
if DOMAIN not in hass.data:
|
||||
raise ValueError("Integration 'application_credentials' not setup")
|
||||
storage_collection = hass.data[DOMAIN][DATA_STORAGE]
|
||||
item = {
|
||||
CONF_DOMAIN: domain,
|
||||
CONF_CLIENT_ID: credential.client_id,
|
||||
CONF_CLIENT_SECRET: credential.client_secret,
|
||||
CONF_AUTH_DOMAIN: domain,
|
||||
}
|
||||
await storage_collection.async_import_item(item)
|
||||
|
||||
|
||||
class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation):
|
||||
"""Application Credentials local oauth2 implementation."""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Name of the implementation."""
|
||||
return self.client_id
|
||||
|
||||
|
||||
async def _async_provide_implementation(
|
||||
hass: HomeAssistant, domain: str
|
||||
) -> list[config_entry_oauth2_flow.AbstractOAuth2Implementation]:
|
||||
"""Return registered OAuth implementations."""
|
||||
|
||||
platform = await _get_platform(hass, domain)
|
||||
if not platform:
|
||||
return []
|
||||
|
||||
authorization_server = await platform.async_get_authorization_server(hass)
|
||||
storage_collection = hass.data[DOMAIN][DATA_STORAGE]
|
||||
credentials = storage_collection.async_client_credentials(domain)
|
||||
return [
|
||||
AuthImplementation(
|
||||
hass,
|
||||
auth_domain,
|
||||
credential.client_id,
|
||||
credential.client_secret,
|
||||
authorization_server.authorize_url,
|
||||
authorization_server.token_url,
|
||||
)
|
||||
for auth_domain, credential in credentials.items()
|
||||
]
|
||||
|
||||
|
||||
class ApplicationCredentialsProtocol(Protocol):
|
||||
"""Define the format that application_credentials platforms can have."""
|
||||
|
||||
async def async_get_authorization_server(
|
||||
self, hass: HomeAssistant
|
||||
) -> AuthorizationServer:
|
||||
"""Return authorization server."""
|
||||
|
||||
|
||||
async def _get_platform(
|
||||
hass: HomeAssistant, integration_domain: str
|
||||
) -> ApplicationCredentialsProtocol | None:
|
||||
"""Register an application_credentials platform."""
|
||||
try:
|
||||
integration = await async_get_integration(hass, integration_domain)
|
||||
except IntegrationNotFound as err:
|
||||
_LOGGER.debug("Integration '%s' does not exist: %s", integration_domain, err)
|
||||
return None
|
||||
try:
|
||||
platform = integration.get_platform("application_credentials")
|
||||
except ImportError as err:
|
||||
_LOGGER.debug(
|
||||
"Integration '%s' does not provide application_credentials: %s",
|
||||
integration_domain,
|
||||
err,
|
||||
)
|
||||
return None
|
||||
if not hasattr(platform, "async_get_authorization_server"):
|
||||
raise ValueError(
|
||||
f"Integration '{integration_domain}' platform application_credentials did not implement 'async_get_authorization_server'"
|
||||
)
|
||||
return platform
|
||||
|
||||
|
||||
@websocket_api.websocket_command(
|
||||
{vol.Required("type"): "application_credentials/config"}
|
||||
)
|
||||
@callback
|
||||
def handle_integration_list(
|
||||
hass: HomeAssistant, connection: ActiveConnection, msg: dict[str, Any]
|
||||
) -> None:
|
||||
"""Handle integrations command."""
|
||||
connection.send_result(msg["id"], {"domains": APPLICATION_CREDENTIALS})
|
@ -0,0 +1,9 @@
|
||||
{
|
||||
"domain": "application_credentials",
|
||||
"name": "Application Credentials",
|
||||
"config_flow": false,
|
||||
"documentation": "https://www.home-assistant.io/integrations/application_credentials",
|
||||
"dependencies": ["auth", "websocket_api"],
|
||||
"codeowners": ["@home-assistant/core"],
|
||||
"quality_scale": "internal"
|
||||
}
|
@ -0,0 +1,3 @@
|
||||
{
|
||||
"title": "Application Credentials"
|
||||
}
|
@ -34,9 +34,9 @@ async def async_provide_implementation(hass: HomeAssistant, domain: str):
|
||||
|
||||
for service in services:
|
||||
if service["service"] == domain and CURRENT_VERSION >= service["min_version"]:
|
||||
return CloudOAuth2Implementation(hass, domain)
|
||||
return [CloudOAuth2Implementation(hass, domain)]
|
||||
|
||||
return
|
||||
return []
|
||||
|
||||
|
||||
async def _get_services(hass):
|
||||
|
@ -3,6 +3,7 @@
|
||||
"name": "Default Config",
|
||||
"documentation": "https://www.home-assistant.io/integrations/default_config",
|
||||
"dependencies": [
|
||||
"application_credentials",
|
||||
"automation",
|
||||
"cloud",
|
||||
"counter",
|
||||
|
@ -20,6 +20,7 @@ from xbox.webapi.api.provider.smartglass.models import (
|
||||
SmartglassConsoleStatus,
|
||||
)
|
||||
|
||||
from homeassistant.components import application_credentials
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
@ -31,8 +32,8 @@ from homeassistant.helpers import (
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator
|
||||
|
||||
from . import api, config_flow
|
||||
from .const import DOMAIN, OAUTH2_AUTHORIZE, OAUTH2_TOKEN
|
||||
from . import api
|
||||
from .const import DOMAIN
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -63,15 +64,11 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
if DOMAIN not in config:
|
||||
return True
|
||||
|
||||
config_flow.OAuth2FlowHandler.async_register_implementation(
|
||||
await application_credentials.async_import_client_credential(
|
||||
hass,
|
||||
config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass,
|
||||
DOMAIN,
|
||||
config[DOMAIN][CONF_CLIENT_ID],
|
||||
config[DOMAIN][CONF_CLIENT_SECRET],
|
||||
OAUTH2_AUTHORIZE,
|
||||
OAUTH2_TOKEN,
|
||||
DOMAIN,
|
||||
application_credentials.ClientCredential(
|
||||
config[DOMAIN][CONF_CLIENT_ID], config[DOMAIN][CONF_CLIENT_SECRET]
|
||||
),
|
||||
)
|
||||
|
||||
|
14
homeassistant/components/xbox/application_credentials.py
Normal file
14
homeassistant/components/xbox/application_credentials.py
Normal file
@ -0,0 +1,14 @@
|
||||
"""Application credentials platform for xbox."""
|
||||
|
||||
from homeassistant.components.application_credentials import AuthorizationServer
|
||||
from homeassistant.core import HomeAssistant
|
||||
|
||||
from .const import OAUTH2_AUTHORIZE, OAUTH2_TOKEN
|
||||
|
||||
|
||||
async def async_get_authorization_server(hass: HomeAssistant) -> AuthorizationServer:
|
||||
"""Return authorization server."""
|
||||
return AuthorizationServer(
|
||||
authorize_url=OAUTH2_AUTHORIZE,
|
||||
token_url=OAUTH2_TOKEN,
|
||||
)
|
@ -4,7 +4,7 @@
|
||||
"config_flow": true,
|
||||
"documentation": "https://www.home-assistant.io/integrations/xbox",
|
||||
"requirements": ["xbox-webapi==2.0.11"],
|
||||
"dependencies": ["auth"],
|
||||
"dependencies": ["auth", "application_credentials"],
|
||||
"codeowners": ["@hunterjm"],
|
||||
"iot_class": "cloud_polling"
|
||||
}
|
||||
|
10
homeassistant/generated/application_credentials.py
Normal file
10
homeassistant/generated/application_credentials.py
Normal file
@ -0,0 +1,10 @@
|
||||
"""Automatically generated by hassfest.
|
||||
|
||||
To update, run python3 -m script.hassfest
|
||||
"""
|
||||
|
||||
# fmt: off
|
||||
|
||||
APPLICATION_CREDENTIALS = [
|
||||
"xbox"
|
||||
]
|
@ -347,10 +347,9 @@ async def async_get_implementations(
|
||||
return registered
|
||||
|
||||
registered = dict(registered)
|
||||
|
||||
for provider_domain, get_impl in hass.data[DATA_PROVIDERS].items():
|
||||
if (implementation := await get_impl(hass, domain)) is not None:
|
||||
registered[provider_domain] = implementation
|
||||
for get_impl in list(hass.data[DATA_PROVIDERS].values()):
|
||||
for impl in await get_impl(hass, domain):
|
||||
registered[impl.domain] = impl
|
||||
|
||||
return registered
|
||||
|
||||
@ -373,7 +372,7 @@ def async_add_implementation_provider(
|
||||
hass: HomeAssistant,
|
||||
provider_domain: str,
|
||||
async_provide_implementation: Callable[
|
||||
[HomeAssistant, str], Awaitable[AbstractOAuth2Implementation | None]
|
||||
[HomeAssistant, str], Awaitable[list[AbstractOAuth2Implementation]]
|
||||
],
|
||||
) -> None:
|
||||
"""Add an implementation provider.
|
||||
|
@ -5,6 +5,7 @@ import sys
|
||||
from time import monotonic
|
||||
|
||||
from . import (
|
||||
application_credentials,
|
||||
codeowners,
|
||||
config_flow,
|
||||
coverage,
|
||||
@ -25,6 +26,7 @@ from . import (
|
||||
from .model import Config, Integration
|
||||
|
||||
INTEGRATION_PLUGINS = [
|
||||
application_credentials,
|
||||
codeowners,
|
||||
config_flow,
|
||||
dependencies,
|
||||
|
63
script/hassfest/application_credentials.py
Normal file
63
script/hassfest/application_credentials.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Generate application_credentials data."""
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
from .model import Config, Integration
|
||||
|
||||
BASE = """
|
||||
\"\"\"Automatically generated by hassfest.
|
||||
|
||||
To update, run python3 -m script.hassfest
|
||||
\"\"\"
|
||||
|
||||
# fmt: off
|
||||
|
||||
APPLICATION_CREDENTIALS = {}
|
||||
""".strip()
|
||||
|
||||
|
||||
def generate_and_validate(integrations: dict[str, Integration], config: Config) -> str:
|
||||
"""Validate and generate config flow data."""
|
||||
|
||||
match_list = []
|
||||
|
||||
for domain in sorted(integrations):
|
||||
integration = integrations[domain]
|
||||
application_credentials_file = integration.path / "application_credentials.py"
|
||||
if not application_credentials_file.is_file():
|
||||
continue
|
||||
|
||||
match_list.append(domain)
|
||||
|
||||
return BASE.format(json.dumps(match_list, indent=4))
|
||||
|
||||
|
||||
def validate(integrations: dict[str, Integration], config: Config) -> None:
|
||||
"""Validate application_credentials data."""
|
||||
application_credentials_path = (
|
||||
config.root / "homeassistant/generated/application_credentials.py"
|
||||
)
|
||||
config.cache["application_credentials"] = content = generate_and_validate(
|
||||
integrations, config
|
||||
)
|
||||
|
||||
if config.specific_integrations:
|
||||
return
|
||||
|
||||
if application_credentials_path.read_text(encoding="utf-8").strip() != content:
|
||||
config.add_error(
|
||||
"application_credentials",
|
||||
"File application_credentials.py is not up to date. Run python3 -m script.hassfest",
|
||||
fixable=True,
|
||||
)
|
||||
|
||||
|
||||
def generate(integrations: dict[str, Integration], config: Config):
|
||||
"""Generate application_credentials data."""
|
||||
application_credentials_path = (
|
||||
config.root / "homeassistant/generated/application_credentials.py"
|
||||
)
|
||||
application_credentials_path.write_text(
|
||||
f"{config.cache['application_credentials']}\n", encoding="utf-8"
|
||||
)
|
@ -36,6 +36,7 @@ SUPPORTED_IOT_CLASSES = [
|
||||
NO_IOT_CLASS = [
|
||||
*{platform.value for platform in Platform},
|
||||
"api",
|
||||
"application_credentials",
|
||||
"auth",
|
||||
"automation",
|
||||
"blueprint",
|
||||
|
1
tests/components/application_credentials/__init__.py
Normal file
1
tests/components/application_credentials/__init__.py
Normal file
@ -0,0 +1 @@
|
||||
"""Tests for the Application Credentials integration."""
|
623
tests/components/application_credentials/test_init.py
Normal file
623
tests/components/application_credentials/test_init.py
Normal file
@ -0,0 +1,623 @@
|
||||
"""Test the Developer Credentials integration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Callable, Generator
|
||||
import logging
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
from aiohttp import ClientWebSocketResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant import config_entries, data_entry_flow
|
||||
from homeassistant.components.application_credentials import (
|
||||
CONF_AUTH_DOMAIN,
|
||||
DOMAIN,
|
||||
AuthorizationServer,
|
||||
ClientCredential,
|
||||
async_import_client_credential,
|
||||
)
|
||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_entry_oauth2_flow
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_platform
|
||||
|
||||
CLIENT_ID = "some-client-id"
|
||||
CLIENT_SECRET = "some-client-secret"
|
||||
DEVELOPER_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET)
|
||||
ID = "fake_integration_some_client_id"
|
||||
AUTHORIZE_URL = "https://example.com/auth"
|
||||
TOKEN_URL = "https://example.com/oauth2/v4/token"
|
||||
REFRESH_TOKEN = "mock-refresh-token"
|
||||
ACCESS_TOKEN = "mock-access-token"
|
||||
|
||||
TEST_DOMAIN = "fake_integration"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def authorization_server() -> AuthorizationServer:
|
||||
"""Fixture AuthorizationServer for mock application_credentials integration."""
|
||||
return AuthorizationServer(AUTHORIZE_URL, TOKEN_URL)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def config_credential() -> ClientCredential | None:
|
||||
"""Fixture ClientCredential for mock application_credentials integration."""
|
||||
return None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def import_config_credential(
|
||||
hass: HomeAssistant, config_credential: ClientCredential
|
||||
) -> None:
|
||||
"""Fixture to import the yaml based credential."""
|
||||
await async_import_client_credential(hass, TEST_DOMAIN, config_credential)
|
||||
|
||||
|
||||
async def setup_application_credentials_integration(
|
||||
hass: HomeAssistant,
|
||||
domain: str,
|
||||
authorization_server: AuthorizationServer,
|
||||
) -> None:
|
||||
"""Set up a fake application_credentials integration."""
|
||||
hass.config.components.add(domain)
|
||||
mock_platform(
|
||||
hass,
|
||||
f"{domain}.application_credentials",
|
||||
Mock(
|
||||
async_get_authorization_server=AsyncMock(return_value=authorization_server),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
async def mock_application_credentials_integration(
|
||||
hass: HomeAssistant,
|
||||
authorization_server: AuthorizationServer,
|
||||
):
|
||||
"""Mock a application_credentials integration."""
|
||||
assert await async_setup_component(hass, "application_credentials", {})
|
||||
await setup_application_credentials_integration(
|
||||
hass, TEST_DOMAIN, authorization_server
|
||||
)
|
||||
|
||||
|
||||
class FakeConfigFlow(config_entry_oauth2_flow.AbstractOAuth2FlowHandler, domain=DOMAIN):
|
||||
"""Config flow used during tests."""
|
||||
|
||||
DOMAIN = TEST_DOMAIN
|
||||
|
||||
@property
|
||||
def logger(self) -> logging.Logger:
|
||||
"""Return logger."""
|
||||
return logging.getLogger(__name__)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def config_flow_handler(
|
||||
hass: HomeAssistant, current_request_with_host: Any
|
||||
) -> Generator[FakeConfigFlow, None, None]:
|
||||
"""Fixture for a test config flow."""
|
||||
mock_platform(hass, f"{TEST_DOMAIN}.config_flow")
|
||||
with patch.dict(config_entries.HANDLERS, {TEST_DOMAIN: FakeConfigFlow}):
|
||||
yield FakeConfigFlow
|
||||
|
||||
|
||||
class OAuthFixture:
|
||||
"""Fixture to facilitate testing an OAuth flow."""
|
||||
|
||||
def __init__(self, hass, hass_client, aioclient_mock):
|
||||
"""Initialize OAuthFixture."""
|
||||
self.hass = hass
|
||||
self.hass_client = hass_client
|
||||
self.aioclient_mock = aioclient_mock
|
||||
self.client_id = CLIENT_ID
|
||||
|
||||
async def complete_external_step(
|
||||
self, result: data_entry_flow.FlowResult
|
||||
) -> data_entry_flow.FlowResult:
|
||||
"""Fixture method to complete the OAuth flow and return the completed result."""
|
||||
client = await self.hass_client()
|
||||
state = config_entry_oauth2_flow._encode_jwt(
|
||||
self.hass,
|
||||
{
|
||||
"flow_id": result["flow_id"],
|
||||
"redirect_uri": "https://example.com/auth/external/callback",
|
||||
},
|
||||
)
|
||||
assert result["url"] == (
|
||||
f"{AUTHORIZE_URL}?response_type=code&client_id={self.client_id}"
|
||||
"&redirect_uri=https://example.com/auth/external/callback"
|
||||
f"&state={state}"
|
||||
)
|
||||
resp = await client.get(f"/auth/external/callback?code=abcd&state={state}")
|
||||
assert resp.status == 200
|
||||
assert resp.headers["content-type"] == "text/html; charset=utf-8"
|
||||
|
||||
self.aioclient_mock.post(
|
||||
TOKEN_URL,
|
||||
json={
|
||||
"refresh_token": REFRESH_TOKEN,
|
||||
"access_token": ACCESS_TOKEN,
|
||||
"type": "bearer",
|
||||
"expires_in": 60,
|
||||
},
|
||||
)
|
||||
|
||||
result = await self.hass.config_entries.flow.async_configure(result["flow_id"])
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||
assert result.get("title") == self.client_id
|
||||
assert "data" in result
|
||||
assert "token" in result["data"]
|
||||
return result
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def oauth_fixture(
|
||||
hass: HomeAssistant, hass_client_no_auth: Any, aioclient_mock: Any
|
||||
) -> OAuthFixture:
|
||||
"""Fixture for testing the OAuth flow."""
|
||||
return OAuthFixture(hass, hass_client_no_auth, aioclient_mock)
|
||||
|
||||
|
||||
class Client:
|
||||
"""Test client with helper methods for application credentials websocket."""
|
||||
|
||||
def __init__(self, client):
|
||||
"""Initialize Client."""
|
||||
self.client = client
|
||||
self.id = 0
|
||||
|
||||
async def cmd(self, cmd: str, payload: dict[str, Any] = None) -> dict[str, Any]:
|
||||
"""Send a command and receive the json result."""
|
||||
self.id += 1
|
||||
await self.client.send_json(
|
||||
{
|
||||
"id": self.id,
|
||||
"type": f"{DOMAIN}/{cmd}",
|
||||
**(payload if payload is not None else {}),
|
||||
}
|
||||
)
|
||||
resp = await self.client.receive_json()
|
||||
assert resp.get("id") == self.id
|
||||
return resp
|
||||
|
||||
async def cmd_result(self, cmd: str, payload: dict[str, Any] = None) -> Any:
|
||||
"""Send a command and parse the result."""
|
||||
resp = await self.cmd(cmd, payload)
|
||||
assert resp.get("success")
|
||||
assert resp.get("type") == "result"
|
||||
return resp.get("result")
|
||||
|
||||
|
||||
ClientFixture = Callable[[], Client]
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def ws_client(
|
||||
hass_ws_client: Callable[[...], ClientWebSocketResponse]
|
||||
) -> ClientFixture:
|
||||
"""Fixture for creating the test websocket client."""
|
||||
|
||||
async def create_client() -> Client:
|
||||
ws_client = await hass_ws_client()
|
||||
return Client(ws_client)
|
||||
|
||||
return create_client
|
||||
|
||||
|
||||
async def test_websocket_list_empty(ws_client: ClientFixture):
|
||||
"""Test websocket list command."""
|
||||
client = await ws_client()
|
||||
assert await client.cmd_result("list") == []
|
||||
|
||||
|
||||
async def test_websocket_create(ws_client: ClientFixture):
|
||||
"""Test websocket create command."""
|
||||
client = await ws_client()
|
||||
result = await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert result == {
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
}
|
||||
|
||||
result = await client.cmd_result("list")
|
||||
assert result == [
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_websocket_create_invalid_domain(ws_client: ClientFixture):
|
||||
"""Test websocket create command."""
|
||||
client = await ws_client()
|
||||
resp = await client.cmd(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: "other-domain",
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "invalid_format"
|
||||
assert (
|
||||
resp["error"].get("message")
|
||||
== "No application_credentials platform for other-domain"
|
||||
)
|
||||
|
||||
|
||||
async def test_websocket_update_not_supported(ws_client: ClientFixture):
|
||||
"""Test websocket update command in unsupported."""
|
||||
client = await ws_client()
|
||||
result = await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert result == {
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
}
|
||||
|
||||
resp = await client.cmd("update", {"application_credentials_id": ID})
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "invalid_format"
|
||||
assert resp["error"].get("message") == "Updates not supported"
|
||||
|
||||
|
||||
async def test_websocket_delete(ws_client: ClientFixture):
|
||||
"""Test websocket delete command."""
|
||||
client = await ws_client()
|
||||
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert await client.cmd_result("list") == [
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
}
|
||||
]
|
||||
|
||||
await client.cmd_result("delete", {"application_credentials_id": ID})
|
||||
assert await client.cmd_result("list") == []
|
||||
|
||||
|
||||
async def test_websocket_delete_item_not_found(ws_client: ClientFixture):
|
||||
"""Test websocket delete command."""
|
||||
client = await ws_client()
|
||||
|
||||
resp = await client.cmd("delete", {"application_credentials_id": ID})
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "not_found"
|
||||
assert (
|
||||
resp["error"].get("message")
|
||||
== f"Unable to find application_credentials_id {ID}"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL])
|
||||
async def test_websocket_import_config(
|
||||
ws_client: ClientFixture,
|
||||
config_credential: ClientCredential,
|
||||
import_config_credential: Any,
|
||||
):
|
||||
"""Test websocket list command for an imported credential."""
|
||||
client = await ws_client()
|
||||
|
||||
# Imported creds returned from websocket
|
||||
assert await client.cmd_result("list") == [
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
||||
}
|
||||
]
|
||||
|
||||
# Imported credential can be deleted
|
||||
await client.cmd_result("delete", {"application_credentials_id": ID})
|
||||
assert await client.cmd_result("list") == []
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL])
|
||||
async def test_import_duplicate_credentials(
|
||||
hass: HomeAssistant,
|
||||
ws_client: ClientFixture,
|
||||
config_credential: ClientCredential,
|
||||
import_config_credential: Any,
|
||||
):
|
||||
"""Exercise duplicate credentials are ignored."""
|
||||
|
||||
# Import the test credential again and verify it is not imported twice
|
||||
await async_import_client_credential(hass, TEST_DOMAIN, DEVELOPER_CREDENTIAL)
|
||||
client = await ws_client()
|
||||
assert await client.cmd_result("list") == [
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
"id": ID,
|
||||
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
||||
}
|
||||
]
|
||||
|
||||
|
||||
async def test_config_flow_no_credentials(hass):
|
||||
"""Test config flow base case with no credentials registered."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result.get("reason") == "missing_configuration"
|
||||
|
||||
|
||||
async def test_config_flow_other_domain(
|
||||
hass: HomeAssistant,
|
||||
ws_client: ClientFixture,
|
||||
authorization_server: AuthorizationServer,
|
||||
):
|
||||
"""Test config flow ignores credentials for another domain."""
|
||||
await setup_application_credentials_integration(
|
||||
hass,
|
||||
"other_domain",
|
||||
authorization_server,
|
||||
)
|
||||
client = await ws_client()
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: "other_domain",
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result.get("reason") == "missing_configuration"
|
||||
|
||||
|
||||
async def test_config_flow(
|
||||
hass: HomeAssistant,
|
||||
ws_client: ClientFixture,
|
||||
oauth_fixture: OAuthFixture,
|
||||
):
|
||||
"""Test config flow with application credential registered."""
|
||||
client = await ws_client()
|
||||
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
result = await oauth_fixture.complete_external_step(result)
|
||||
assert (
|
||||
result["data"].get("auth_implementation") == "fake_integration_some_client_id"
|
||||
)
|
||||
|
||||
# Verify it is not possible to delete an in-use config entry
|
||||
resp = await client.cmd("delete", {"application_credentials_id": ID})
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "unknown_error"
|
||||
|
||||
|
||||
async def test_config_flow_multiple_entries(
|
||||
hass: HomeAssistant,
|
||||
ws_client: ClientFixture,
|
||||
oauth_fixture: OAuthFixture,
|
||||
):
|
||||
"""Test config flow with multiple application credentials registered."""
|
||||
client = await ws_client()
|
||||
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID + "2",
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET + "2",
|
||||
},
|
||||
)
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_FORM
|
||||
assert result.get("step_id") == "pick_implementation"
|
||||
|
||||
result = await hass.config_entries.flow.async_configure(
|
||||
result["flow_id"],
|
||||
user_input={"implementation": "fake_integration_some_client_id2"},
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
oauth_fixture.client_id = CLIENT_ID + "2"
|
||||
result = await oauth_fixture.complete_external_step(result)
|
||||
assert (
|
||||
result["data"].get("auth_implementation") == "fake_integration_some_client_id2"
|
||||
)
|
||||
|
||||
|
||||
async def test_config_flow_create_delete_credential(
|
||||
hass: HomeAssistant,
|
||||
ws_client: ClientFixture,
|
||||
oauth_fixture: OAuthFixture,
|
||||
):
|
||||
"""Test adding and deleting a credential unregisters from the config flow."""
|
||||
client = await ws_client()
|
||||
|
||||
await client.cmd_result(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
await client.cmd("delete", {"application_credentials_id": ID})
|
||||
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result.get("reason") == "missing_configuration"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL])
|
||||
async def test_config_flow_with_config_credential(
|
||||
hass,
|
||||
hass_client_no_auth,
|
||||
aioclient_mock,
|
||||
oauth_fixture,
|
||||
config_credential,
|
||||
import_config_credential,
|
||||
):
|
||||
"""Test config flow with application credential registered."""
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||
result = await oauth_fixture.complete_external_step(result)
|
||||
# Uses the imported auth domain for compatibility
|
||||
assert result["data"].get("auth_implementation") == TEST_DOMAIN
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mock_application_credentials_integration", [None])
|
||||
async def test_import_without_setup(hass, config_credential):
|
||||
"""Test import of credentials without setting up the integration."""
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
await async_import_client_credential(hass, TEST_DOMAIN, config_credential)
|
||||
|
||||
# Config flow does not have authentication
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result.get("reason") == "missing_configuration"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mock_application_credentials_integration", [None])
|
||||
async def test_websocket_without_platform(
|
||||
hass: HomeAssistant, ws_client: ClientFixture
|
||||
):
|
||||
"""Test an integration without the application credential platform."""
|
||||
assert await async_setup_component(hass, "application_credentials", {})
|
||||
hass.config.components.add(TEST_DOMAIN)
|
||||
|
||||
client = await ws_client()
|
||||
resp = await client.cmd(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "invalid_format"
|
||||
|
||||
# Config flow does not have authentication
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_ABORT
|
||||
assert result.get("reason") == "missing_configuration"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("mock_application_credentials_integration", [None])
|
||||
async def test_websocket_without_authorization_server(
|
||||
hass: HomeAssistant, ws_client: ClientFixture
|
||||
):
|
||||
"""Test platform with incorrect implementation."""
|
||||
assert await async_setup_component(hass, "application_credentials", {})
|
||||
hass.config.components.add(TEST_DOMAIN)
|
||||
|
||||
# Platform does not implemenent async_get_authorization_server
|
||||
platform = Mock()
|
||||
del platform.async_get_authorization_server
|
||||
mock_platform(
|
||||
hass,
|
||||
f"{TEST_DOMAIN}.application_credentials",
|
||||
platform,
|
||||
)
|
||||
|
||||
client = await ws_client()
|
||||
resp = await client.cmd(
|
||||
"create",
|
||||
{
|
||||
CONF_DOMAIN: TEST_DOMAIN,
|
||||
CONF_CLIENT_ID: CLIENT_ID,
|
||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||
},
|
||||
)
|
||||
assert not resp.get("success")
|
||||
assert "error" in resp
|
||||
assert resp["error"].get("code") == "invalid_format"
|
||||
|
||||
# Config flow does not have authentication
|
||||
with pytest.raises(ValueError):
|
||||
await hass.config_entries.flow.async_init(
|
||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||
)
|
||||
|
||||
|
||||
async def test_websocket_integration_list(ws_client: ClientFixture):
|
||||
"""Test websocket integration list command."""
|
||||
client = await ws_client()
|
||||
with patch(
|
||||
"homeassistant.components.application_credentials.APPLICATION_CREDENTIALS",
|
||||
["example1", "example2"],
|
||||
):
|
||||
assert await client.cmd_result("config") == {
|
||||
"domains": ["example1", "example2"]
|
||||
}
|
@ -537,11 +537,11 @@ async def test_implementation_provider(hass, local_impl):
|
||||
hass, mock_domain_with_impl
|
||||
) == {TEST_DOMAIN: local_impl}
|
||||
|
||||
provider_source = {}
|
||||
provider_source = []
|
||||
|
||||
async def async_provide_implementation(hass, domain):
|
||||
"""Mock implementation provider."""
|
||||
return provider_source.get(domain)
|
||||
return provider_source
|
||||
|
||||
config_entry_oauth2_flow.async_add_implementation_provider(
|
||||
hass, "cloud", async_provide_implementation
|
||||
@ -551,15 +551,29 @@ async def test_implementation_provider(hass, local_impl):
|
||||
hass, mock_domain_with_impl
|
||||
) == {TEST_DOMAIN: local_impl}
|
||||
|
||||
provider_source[
|
||||
mock_domain_with_impl
|
||||
] = config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
|
||||
provider_source.append(
|
||||
config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass, "cloud", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
|
||||
)
|
||||
)
|
||||
|
||||
assert await config_entry_oauth2_flow.async_get_implementations(
|
||||
hass, mock_domain_with_impl
|
||||
) == {TEST_DOMAIN: local_impl, "cloud": provider_source[mock_domain_with_impl]}
|
||||
) == {TEST_DOMAIN: local_impl, "cloud": provider_source[0]}
|
||||
|
||||
provider_source.append(
|
||||
config_entry_oauth2_flow.LocalOAuth2Implementation(
|
||||
hass, "other", CLIENT_ID, CLIENT_SECRET, AUTHORIZE_URL, TOKEN_URL
|
||||
)
|
||||
)
|
||||
|
||||
assert await config_entry_oauth2_flow.async_get_implementations(
|
||||
hass, mock_domain_with_impl
|
||||
) == {
|
||||
TEST_DOMAIN: local_impl,
|
||||
"cloud": provider_source[0],
|
||||
"other": provider_source[1],
|
||||
}
|
||||
|
||||
|
||||
async def test_oauth_session_refresh_failure(
|
||||
|
Loading…
x
Reference in New Issue
Block a user