Add application credentials platform for google calendar integration (#71808)

* Add google application_credentials platform

* Further simplify custom auth implementation overrides

* Add test coverage in application_credentials

* Simplify wording in a comment

* Remove unused imports accidentally left from merge

* Wrap lines that are too long for style guide

* Move application credential loading to only where it is needed

* Leave CLIENT_ID and CLIENT_SECRET as required.
This commit is contained in:
Allen Porter 2022-05-14 10:27:47 -07:00 committed by GitHub
parent 656e88faec
commit 355445db2d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 258 additions and 47 deletions

View File

@ -151,7 +151,10 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_import_client_credential( async def async_import_client_credential(
hass: HomeAssistant, domain: str, credential: ClientCredential hass: HomeAssistant,
domain: str,
credential: ClientCredential,
auth_domain: str = None,
) -> None: ) -> None:
"""Import an existing credential from configuration.yaml.""" """Import an existing credential from configuration.yaml."""
if DOMAIN not in hass.data: if DOMAIN not in hass.data:
@ -161,7 +164,7 @@ async def async_import_client_credential(
CONF_DOMAIN: domain, CONF_DOMAIN: domain,
CONF_CLIENT_ID: credential.client_id, CONF_CLIENT_ID: credential.client_id,
CONF_CLIENT_SECRET: credential.client_secret, CONF_CLIENT_SECRET: credential.client_secret,
CONF_AUTH_DOMAIN: domain, CONF_AUTH_DOMAIN: auth_domain if auth_domain else domain,
} }
await storage_collection.async_import_item(item) await storage_collection.async_import_item(item)
@ -169,6 +172,23 @@ async def async_import_client_credential(
class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation): class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation):
"""Application Credentials local oauth2 implementation.""" """Application Credentials local oauth2 implementation."""
def __init__(
self,
hass: HomeAssistant,
auth_domain: str,
credential: ClientCredential,
authorization_server: AuthorizationServer,
) -> None:
"""Initialize AuthImplementation."""
super().__init__(
hass,
auth_domain,
credential.client_id,
credential.client_secret,
authorization_server.authorize_url,
authorization_server.token_url,
)
@property @property
def name(self) -> str: def name(self) -> str:
"""Name of the implementation.""" """Name of the implementation."""
@ -184,29 +204,38 @@ async def _async_provide_implementation(
if not platform: if not platform:
return [] return []
authorization_server = await platform.async_get_authorization_server(hass)
storage_collection = hass.data[DOMAIN][DATA_STORAGE] storage_collection = hass.data[DOMAIN][DATA_STORAGE]
credentials = storage_collection.async_client_credentials(domain) credentials = storage_collection.async_client_credentials(domain)
if hasattr(platform, "async_get_auth_implementation"):
return [
await platform.async_get_auth_implementation(hass, auth_domain, credential)
for auth_domain, credential in credentials.items()
]
authorization_server = await platform.async_get_authorization_server(hass)
return [ return [
AuthImplementation( AuthImplementation(hass, auth_domain, credential, authorization_server)
hass,
auth_domain,
credential.client_id,
credential.client_secret,
authorization_server.authorize_url,
authorization_server.token_url,
)
for auth_domain, credential in credentials.items() for auth_domain, credential in credentials.items()
] ]
class ApplicationCredentialsProtocol(Protocol): class ApplicationCredentialsProtocol(Protocol):
"""Define the format that application_credentials platforms can have.""" """Define the format that application_credentials platforms may have.
Most platforms typically just implement async_get_authorization_server, and
the default oauth implementation will be used. Otherwise a platform may
implement async_get_auth_implementation to give their use a custom
AbstractOAuth2Implementation.
"""
async def async_get_authorization_server( async def async_get_authorization_server(
self, hass: HomeAssistant self, hass: HomeAssistant
) -> AuthorizationServer: ) -> AuthorizationServer:
"""Return authorization server.""" """Return authorization server, for the default auth implementation."""
async def async_get_auth_implementation(
self, hass: HomeAssistant, auth_domain: str, credential: ClientCredential
) -> config_entry_oauth2_flow.AbstractOAuth2Implementation:
"""Return a custom auth implementation."""
async def _get_platform( async def _get_platform(
@ -227,9 +256,12 @@ async def _get_platform(
err, err,
) )
return None return None
if not hasattr(platform, "async_get_authorization_server"): if not hasattr(platform, "async_get_authorization_server") and not hasattr(
platform, "async_get_auth_implementation"
):
raise ValueError( raise ValueError(
f"Integration '{integration_domain}' platform application_credentials did not implement 'async_get_authorization_server'" f"Integration '{integration_domain}' platform {DOMAIN} did not "
f"implement 'async_get_authorization_server' or 'async_get_auth_implementation'"
) )
return platform return platform

View File

@ -17,6 +17,10 @@ from voluptuous.error import Error as VoluptuousError
import yaml import yaml
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.application_credentials import (
ClientCredential,
async_import_client_credential,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import ( from homeassistant.const import (
CONF_CLIENT_ID, CONF_CLIENT_ID,
@ -39,12 +43,12 @@ from homeassistant.helpers.dispatcher import async_dispatcher_send
from homeassistant.helpers.entity import generate_entity_id from homeassistant.helpers.entity import generate_entity_id
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from . import config_flow from .api import ApiAuthImpl, get_feature_access
from .api import ApiAuthImpl, DeviceAuth, get_feature_access
from .const import ( from .const import (
CONF_CALENDAR_ACCESS, CONF_CALENDAR_ACCESS,
DATA_CONFIG, DATA_CONFIG,
DATA_SERVICE, DATA_SERVICE,
DEVICE_AUTH_IMPL,
DISCOVER_CALENDAR, DISCOVER_CALENDAR,
DOMAIN, DOMAIN,
FeatureAccess, FeatureAccess,
@ -159,14 +163,17 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Google component.""" """Set up the Google component."""
conf = config.get(DOMAIN, {}) conf = config.get(DOMAIN, {})
hass.data[DOMAIN] = {DATA_CONFIG: conf} hass.data[DOMAIN] = {DATA_CONFIG: conf}
config_flow.OAuth2FlowHandler.async_register_implementation(
hass, if CONF_CLIENT_ID in conf and CONF_CLIENT_SECRET in conf:
DeviceAuth( await async_import_client_credential(
hass, hass,
conf[CONF_CLIENT_ID], DOMAIN,
conf[CONF_CLIENT_SECRET], ClientCredential(
), conf[CONF_CLIENT_ID],
) conf[CONF_CLIENT_SECRET],
),
DEVICE_AUTH_IMPL,
)
# Import credentials from the old token file into the new way as # Import credentials from the old token file into the new way as
# a ConfigEntry managed by home assistant. # a ConfigEntry managed by home assistant.

View File

@ -10,7 +10,6 @@ from typing import Any
import aiohttp import aiohttp
from gcal_sync.auth import AbstractAuth from gcal_sync.auth import AbstractAuth
import oauth2client
from oauth2client.client import ( from oauth2client.client import (
Credentials, Credentials,
DeviceFlowInfo, DeviceFlowInfo,
@ -19,6 +18,7 @@ from oauth2client.client import (
OAuth2WebServerFlow, OAuth2WebServerFlow,
) )
from homeassistant.components.application_credentials import AuthImplementation
from homeassistant.core import CALLBACK_TYPE, HomeAssistant from homeassistant.core import CALLBACK_TYPE, HomeAssistant
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
@ -28,7 +28,6 @@ from .const import (
CONF_CALENDAR_ACCESS, CONF_CALENDAR_ACCESS,
DATA_CONFIG, DATA_CONFIG,
DEFAULT_FEATURE_ACCESS, DEFAULT_FEATURE_ACCESS,
DEVICE_AUTH_IMPL,
DOMAIN, DOMAIN,
FeatureAccess, FeatureAccess,
) )
@ -44,20 +43,9 @@ class OAuthError(Exception):
"""OAuth related error.""" """OAuth related error."""
class DeviceAuth(config_entry_oauth2_flow.LocalOAuth2Implementation): class DeviceAuth(AuthImplementation):
"""OAuth implementation for Device Auth.""" """OAuth implementation for Device Auth."""
def __init__(self, hass: HomeAssistant, client_id: str, client_secret: str) -> None:
"""Initialize InstalledAppAuth."""
super().__init__(
hass,
DEVICE_AUTH_IMPL,
client_id,
client_secret,
oauth2client.GOOGLE_AUTH_URI,
oauth2client.GOOGLE_TOKEN_URI,
)
async def async_resolve_external_data(self, external_data: Any) -> dict: async def async_resolve_external_data(self, external_data: Any) -> dict:
"""Resolve a Google API Credentials object to Home Assistant token.""" """Resolve a Google API Credentials object to Home Assistant token."""
creds: Credentials = external_data[DEVICE_AUTH_CREDS] creds: Credentials = external_data[DEVICE_AUTH_CREDS]

View File

@ -0,0 +1,23 @@
"""application_credentials platform for nest."""
import oauth2client
from homeassistant.components.application_credentials import (
AuthorizationServer,
ClientCredential,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_entry_oauth2_flow
from .api import DeviceAuth
AUTHORIZATION_SERVER = AuthorizationServer(
oauth2client.GOOGLE_AUTH_URI, oauth2client.GOOGLE_TOKEN_URI
)
async def async_get_auth_implementation(
hass: HomeAssistant, auth_domain: str, credential: ClientCredential
) -> config_entry_oauth2_flow.AbstractOAuth2Implementation:
"""Return auth implementation."""
return DeviceAuth(hass, auth_domain, credential, AUTHORIZATION_SERVER)

View File

@ -2,7 +2,7 @@
"domain": "google", "domain": "google",
"name": "Google Calendars", "name": "Google Calendars",
"config_flow": true, "config_flow": true,
"dependencies": ["auth"], "dependencies": ["application_credentials"],
"documentation": "https://www.home-assistant.io/integrations/calendar.google/", "documentation": "https://www.home-assistant.io/integrations/calendar.google/",
"requirements": ["gcal-sync==0.7.1", "oauth2client==4.1.3"], "requirements": ["gcal-sync==0.7.1", "oauth2client==4.1.3"],
"codeowners": ["@allenporter"], "codeowners": ["@allenporter"],

View File

@ -6,5 +6,6 @@ To update, run python3 -m script.hassfest
# fmt: off # fmt: off
APPLICATION_CREDENTIALS = [ APPLICATION_CREDENTIALS = [
"google",
"xbox" "xbox"
] ]

View File

@ -14,6 +14,7 @@ from homeassistant import config_entries, data_entry_flow
from homeassistant.components.application_credentials import ( from homeassistant.components.application_credentials import (
CONF_AUTH_DOMAIN, CONF_AUTH_DOMAIN,
DOMAIN, DOMAIN,
AuthImplementation,
AuthorizationServer, AuthorizationServer,
ClientCredential, ClientCredential,
async_import_client_credential, async_import_client_credential,
@ -64,12 +65,14 @@ async def setup_application_credentials_integration(
) -> None: ) -> None:
"""Set up a fake application_credentials integration.""" """Set up a fake application_credentials integration."""
hass.config.components.add(domain) hass.config.components.add(domain)
mock_platform_impl = Mock(
async_get_authorization_server=AsyncMock(return_value=authorization_server),
)
del mock_platform_impl.async_get_auth_implementation # return False on hasattr
mock_platform( mock_platform(
hass, hass,
f"{domain}.application_credentials", f"{domain}.application_credentials",
Mock( mock_platform_impl,
async_get_authorization_server=AsyncMock(return_value=authorization_server),
),
) )
@ -585,6 +588,7 @@ async def test_websocket_without_authorization_server(
# Platform does not implemenent async_get_authorization_server # Platform does not implemenent async_get_authorization_server
platform = Mock() platform = Mock()
del platform.async_get_authorization_server del platform.async_get_authorization_server
del platform.async_get_auth_implementation
mock_platform( mock_platform(
hass, hass,
f"{TEST_DOMAIN}.application_credentials", f"{TEST_DOMAIN}.application_credentials",
@ -611,6 +615,45 @@ async def test_websocket_without_authorization_server(
) )
@pytest.mark.parametrize("config_credential", [DEVELOPER_CREDENTIAL])
async def test_platform_with_auth_implementation(
hass,
hass_client_no_auth,
aioclient_mock,
oauth_fixture,
config_credential,
import_config_credential,
authorization_server,
):
"""Test config flow with custom OAuth2 implementation."""
assert await async_setup_component(hass, "application_credentials", {})
hass.config.components.add(TEST_DOMAIN)
async def get_auth_impl(
hass: HomeAssistant, auth_domain: str, credential: ClientCredential
) -> config_entry_oauth2_flow.AbstractOAuth2Implementation:
return AuthImplementation(hass, auth_domain, credential, authorization_server)
mock_platform_impl = Mock(
async_get_auth_implementation=get_auth_impl,
)
del mock_platform_impl.async_get_authorization_server
mock_platform(
hass,
f"{TEST_DOMAIN}.application_credentials",
mock_platform_impl,
)
result = await hass.config_entries.flow.async_init(
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
result = await oauth_fixture.complete_external_step(result)
# Uses the imported auth domain for compatibility
assert result["data"].get("auth_implementation") == TEST_DOMAIN
async def test_websocket_integration_list(ws_client: ClientFixture): async def test_websocket_integration_list(ws_client: ClientFixture):
"""Test websocket integration list command.""" """Test websocket integration list command."""
client = await ws_client() client = await ws_client()

View File

@ -171,7 +171,8 @@ def config_entry_token_expiry(token_expiry: datetime.datetime) -> float:
@pytest.fixture @pytest.fixture
def config_entry( def config_entry(
token_scopes: list[str], config_entry_token_expiry: float token_scopes: list[str],
config_entry_token_expiry: float,
) -> MockConfigEntry: ) -> MockConfigEntry:
"""Fixture to create a config entry for the integration.""" """Fixture to create a config entry for the integration."""
return MockConfigEntry( return MockConfigEntry(
@ -291,7 +292,7 @@ def google_config(google_config_track_new: bool | None) -> dict[str, Any]:
@pytest.fixture @pytest.fixture
def config(google_config: dict[str, Any]) -> dict[str, Any]: def config(google_config: dict[str, Any]) -> dict[str, Any]:
"""Fixture for overriding component config.""" """Fixture for overriding component config."""
return {DOMAIN: google_config} return {DOMAIN: google_config} if google_config else {}
@pytest.fixture @pytest.fixture

View File

@ -1,6 +1,7 @@
"""Test the google config flow.""" """Test the google config flow."""
import datetime import datetime
from typing import Any
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from oauth2client.client import ( from oauth2client.client import (
@ -11,6 +12,10 @@ from oauth2client.client import (
import pytest import pytest
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.application_credentials import (
ClientCredential,
async_import_client_credential,
)
from homeassistant.components.google.const import DOMAIN from homeassistant.components.google.const import DOMAIN
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
@ -65,7 +70,7 @@ async def fire_alarm(hass, point_in_time):
await hass.async_block_till_done() await hass.async_block_till_done()
async def test_full_flow( async def test_full_flow_yaml_creds(
hass: HomeAssistant, hass: HomeAssistant,
mock_code_flow: Mock, mock_code_flow: Mock,
mock_exchange: Mock, mock_exchange: Mock,
@ -94,7 +99,7 @@ async def test_full_flow(
) )
assert result.get("type") == "create_entry" assert result.get("type") == "create_entry"
assert result.get("title") == "Configuration.yaml" assert result.get("title") == "client-id"
assert "data" in result assert "data" in result
data = result["data"] data = result["data"]
assert "token" in data assert "token" in data
@ -121,6 +126,68 @@ async def test_full_flow(
assert len(entries) == 1 assert len(entries) == 1
@pytest.mark.parametrize("google_config", [None])
async def test_full_flow_application_creds(
hass: HomeAssistant,
mock_code_flow: Mock,
mock_exchange: Mock,
config: dict[str, Any],
component_setup: ComponentSetup,
) -> None:
"""Test successful creds setup."""
assert await component_setup()
await async_import_client_credential(
hass, DOMAIN, ClientCredential("client-id", "client-secret"), "imported-cred"
)
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result.get("type") == "progress"
assert result.get("step_id") == "auth"
assert "description_placeholders" in result
assert "url" in result["description_placeholders"]
with patch(
"homeassistant.components.google.async_setup_entry", return_value=True
) as mock_setup:
# Run one tick to invoke the credential exchange check
now = utcnow()
await fire_alarm(hass, now + CODE_CHECK_ALARM_TIMEDELTA)
await hass.async_block_till_done()
result = await hass.config_entries.flow.async_configure(
flow_id=result["flow_id"]
)
assert result.get("type") == "create_entry"
assert result.get("title") == "client-id"
assert "data" in result
data = result["data"]
assert "token" in data
assert 0 < data["token"]["expires_in"] < 8 * 86400
assert (
datetime.datetime.now().timestamp()
<= data["token"]["expires_at"]
< (datetime.datetime.now() + datetime.timedelta(days=8)).timestamp()
)
data["token"].pop("expires_at")
data["token"].pop("expires_in")
assert data == {
"auth_implementation": "imported-cred",
"token": {
"access_token": "ACCESS_TOKEN",
"refresh_token": "REFRESH_TOKEN",
"scope": "https://www.googleapis.com/auth/calendar",
"token_type": "Bearer",
},
}
assert len(mock_setup.mock_calls) == 1
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
async def test_code_error( async def test_code_error(
hass: HomeAssistant, hass: HomeAssistant,
mock_code_flow: Mock, mock_code_flow: Mock,
@ -211,7 +278,7 @@ async def test_exchange_error(
) )
assert result.get("type") == "create_entry" assert result.get("type") == "create_entry"
assert result.get("title") == "Configuration.yaml" assert result.get("title") == "client-id"
assert "data" in result assert "data" in result
data = result["data"] data = result["data"]
assert "token" in data assert "token" in data
@ -263,6 +330,21 @@ async def test_missing_configuration(
assert result.get("reason") == "missing_configuration" assert result.get("reason") == "missing_configuration"
@pytest.mark.parametrize("google_config", [None])
async def test_missing_configuration_yaml_empty(
hass: HomeAssistant,
component_setup: ComponentSetup,
) -> None:
"""Test setup with an empty yaml configuration and no credentials."""
assert await component_setup()
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result.get("type") == "abort"
assert result.get("reason") == "missing_configuration"
async def test_wrong_configuration( async def test_wrong_configuration(
hass: HomeAssistant, hass: HomeAssistant,
) -> None: ) -> None:

View File

@ -10,6 +10,10 @@ from unittest.mock import patch
import pytest import pytest
from homeassistant.components.application_credentials import (
ClientCredential,
async_import_client_credential,
)
from homeassistant.components.google import ( from homeassistant.components.google import (
DOMAIN, DOMAIN,
SERVICE_ADD_EVENT, SERVICE_ADD_EVENT,
@ -18,6 +22,7 @@ from homeassistant.components.google import (
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.const import STATE_OFF from homeassistant.const import STATE_OFF
from homeassistant.core import HomeAssistant, State from homeassistant.core import HomeAssistant, State
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow from homeassistant.util.dt import utcnow
from .conftest import ( from .conftest import (
@ -224,6 +229,35 @@ async def test_found_calendar_from_api(
assert not hass.states.get(TEST_YAML_ENTITY) assert not hass.states.get(TEST_YAML_ENTITY)
@pytest.mark.parametrize("calendars_config,google_config", [([], {})])
async def test_load_application_credentials(
hass: HomeAssistant,
component_setup: ComponentSetup,
mock_calendars_yaml: None,
mock_calendars_list: ApiResult,
test_api_calendar: dict[str, Any],
mock_events_list: ApiResult,
setup_config_entry: MockConfigEntry,
) -> None:
"""Test loading an application credentials and a config entry."""
assert await async_setup_component(hass, "application_credentials", {})
await async_import_client_credential(
hass, DOMAIN, ClientCredential("client-id", "client-secret"), "device_auth"
)
mock_calendars_list({"items": [test_api_calendar]})
mock_events_list({})
assert await component_setup()
state = hass.states.get(TEST_API_ENTITY)
assert state
assert state.name == TEST_API_ENTITY_NAME
assert state.state == STATE_OFF
# No yaml config loaded that overwrites the entity name
assert not hass.states.get(TEST_YAML_ENTITY)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"calendars_config_track,expected_state", "calendars_config_track,expected_state",
[ [