mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Add display name for application credentials (#72053)
* Add display name for application credentials * Rename display name to name * Improve test coverage for importing a named credential * Add a default credential name on import
This commit is contained in:
parent
edd7a3427c
commit
a6402697bb
@ -15,7 +15,13 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.websocket_api.connection import ActiveConnection
|
from homeassistant.components.websocket_api.connection import ActiveConnection
|
||||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN, CONF_ID
|
from homeassistant.const import (
|
||||||
|
CONF_CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET,
|
||||||
|
CONF_DOMAIN,
|
||||||
|
CONF_ID,
|
||||||
|
CONF_NAME,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.exceptions import HomeAssistantError
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
from homeassistant.helpers import collection, config_entry_oauth2_flow
|
from homeassistant.helpers import collection, config_entry_oauth2_flow
|
||||||
@ -39,12 +45,14 @@ STORAGE_KEY = DOMAIN
|
|||||||
STORAGE_VERSION = 1
|
STORAGE_VERSION = 1
|
||||||
DATA_STORAGE = "storage"
|
DATA_STORAGE = "storage"
|
||||||
CONF_AUTH_DOMAIN = "auth_domain"
|
CONF_AUTH_DOMAIN = "auth_domain"
|
||||||
|
DEFAULT_IMPORT_NAME = "Import from configuration.yaml"
|
||||||
|
|
||||||
CREATE_FIELDS = {
|
CREATE_FIELDS = {
|
||||||
vol.Required(CONF_DOMAIN): cv.string,
|
vol.Required(CONF_DOMAIN): cv.string,
|
||||||
vol.Required(CONF_CLIENT_ID): cv.string,
|
vol.Required(CONF_CLIENT_ID): cv.string,
|
||||||
vol.Required(CONF_CLIENT_SECRET): cv.string,
|
vol.Required(CONF_CLIENT_SECRET): cv.string,
|
||||||
vol.Optional(CONF_AUTH_DOMAIN): cv.string,
|
vol.Optional(CONF_AUTH_DOMAIN): cv.string,
|
||||||
|
vol.Optional(CONF_NAME): cv.string,
|
||||||
}
|
}
|
||||||
UPDATE_FIELDS: dict = {} # Not supported
|
UPDATE_FIELDS: dict = {} # Not supported
|
||||||
|
|
||||||
@ -55,6 +63,7 @@ class ClientCredential:
|
|||||||
|
|
||||||
client_id: str
|
client_id: str
|
||||||
client_secret: str
|
client_secret: str
|
||||||
|
name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -122,7 +131,9 @@ class ApplicationCredentialsStorageCollection(collection.StorageCollection):
|
|||||||
item[CONF_AUTH_DOMAIN] if CONF_AUTH_DOMAIN in item else item[CONF_ID]
|
item[CONF_AUTH_DOMAIN] if CONF_AUTH_DOMAIN in item else item[CONF_ID]
|
||||||
)
|
)
|
||||||
credentials[auth_domain] = ClientCredential(
|
credentials[auth_domain] = ClientCredential(
|
||||||
item[CONF_CLIENT_ID], item[CONF_CLIENT_SECRET]
|
client_id=item[CONF_CLIENT_ID],
|
||||||
|
client_secret=item[CONF_CLIENT_SECRET],
|
||||||
|
name=item.get(CONF_NAME),
|
||||||
)
|
)
|
||||||
return credentials
|
return credentials
|
||||||
|
|
||||||
@ -169,6 +180,7 @@ async def async_import_client_credential(
|
|||||||
CONF_CLIENT_SECRET: credential.client_secret,
|
CONF_CLIENT_SECRET: credential.client_secret,
|
||||||
CONF_AUTH_DOMAIN: auth_domain if auth_domain else domain,
|
CONF_AUTH_DOMAIN: auth_domain if auth_domain else domain,
|
||||||
}
|
}
|
||||||
|
item[CONF_NAME] = credential.name if credential.name else DEFAULT_IMPORT_NAME
|
||||||
await storage_collection.async_import_item(item)
|
await storage_collection.async_import_item(item)
|
||||||
|
|
||||||
|
|
||||||
@ -191,11 +203,12 @@ class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation):
|
|||||||
authorization_server.authorize_url,
|
authorization_server.authorize_url,
|
||||||
authorization_server.token_url,
|
authorization_server.token_url,
|
||||||
)
|
)
|
||||||
|
self._name = credential.name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Name of the implementation."""
|
"""Name of the implementation."""
|
||||||
return self.client_id
|
return self._name or self.client_id
|
||||||
|
|
||||||
|
|
||||||
async def _async_provide_implementation(
|
async def _async_provide_implementation(
|
||||||
|
@ -13,13 +13,19 @@ import pytest
|
|||||||
from homeassistant import config_entries, data_entry_flow
|
from homeassistant import config_entries, data_entry_flow
|
||||||
from homeassistant.components.application_credentials import (
|
from homeassistant.components.application_credentials import (
|
||||||
CONF_AUTH_DOMAIN,
|
CONF_AUTH_DOMAIN,
|
||||||
|
DEFAULT_IMPORT_NAME,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
AuthImplementation,
|
AuthImplementation,
|
||||||
AuthorizationServer,
|
AuthorizationServer,
|
||||||
ClientCredential,
|
ClientCredential,
|
||||||
async_import_client_credential,
|
async_import_client_credential,
|
||||||
)
|
)
|
||||||
from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN
|
from homeassistant.const import (
|
||||||
|
CONF_CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET,
|
||||||
|
CONF_DOMAIN,
|
||||||
|
CONF_NAME,
|
||||||
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -29,11 +35,13 @@ from tests.common import mock_platform
|
|||||||
CLIENT_ID = "some-client-id"
|
CLIENT_ID = "some-client-id"
|
||||||
CLIENT_SECRET = "some-client-secret"
|
CLIENT_SECRET = "some-client-secret"
|
||||||
DEVELOPER_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET)
|
DEVELOPER_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET)
|
||||||
|
NAMED_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET, "Name")
|
||||||
ID = "fake_integration_some_client_id"
|
ID = "fake_integration_some_client_id"
|
||||||
AUTHORIZE_URL = "https://example.com/auth"
|
AUTHORIZE_URL = "https://example.com/auth"
|
||||||
TOKEN_URL = "https://example.com/oauth2/v4/token"
|
TOKEN_URL = "https://example.com/oauth2/v4/token"
|
||||||
REFRESH_TOKEN = "mock-refresh-token"
|
REFRESH_TOKEN = "mock-refresh-token"
|
||||||
ACCESS_TOKEN = "mock-access-token"
|
ACCESS_TOKEN = "mock-access-token"
|
||||||
|
NAME = "Name"
|
||||||
|
|
||||||
TEST_DOMAIN = "fake_integration"
|
TEST_DOMAIN = "fake_integration"
|
||||||
|
|
||||||
@ -118,6 +126,7 @@ class OAuthFixture:
|
|||||||
self.hass_client = hass_client
|
self.hass_client = hass_client
|
||||||
self.aioclient_mock = aioclient_mock
|
self.aioclient_mock = aioclient_mock
|
||||||
self.client_id = CLIENT_ID
|
self.client_id = CLIENT_ID
|
||||||
|
self.title = CLIENT_ID
|
||||||
|
|
||||||
async def complete_external_step(
|
async def complete_external_step(
|
||||||
self, result: data_entry_flow.FlowResult
|
self, result: data_entry_flow.FlowResult
|
||||||
@ -152,7 +161,7 @@ class OAuthFixture:
|
|||||||
|
|
||||||
result = await self.hass.config_entries.flow.async_configure(result["flow_id"])
|
result = await self.hass.config_entries.flow.async_configure(result["flow_id"])
|
||||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
assert result.get("type") == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
|
||||||
assert result.get("title") == self.client_id
|
assert result.get("title") == self.title
|
||||||
assert "data" in result
|
assert "data" in result
|
||||||
assert "token" in result["data"]
|
assert "token" in result["data"]
|
||||||
return result
|
return result
|
||||||
@ -348,6 +357,7 @@ async def test_websocket_import_config(
|
|||||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
"id": ID,
|
"id": ID,
|
||||||
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_NAME: DEFAULT_IMPORT_NAME,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -375,6 +385,29 @@ async def test_import_duplicate_credentials(
|
|||||||
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
"id": ID,
|
"id": ID,
|
||||||
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_NAME: DEFAULT_IMPORT_NAME,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("config_credential", [NAMED_CREDENTIAL])
|
||||||
|
async def test_import_named_credential(
|
||||||
|
ws_client: ClientFixture,
|
||||||
|
config_credential: ClientCredential,
|
||||||
|
import_config_credential: Any,
|
||||||
|
):
|
||||||
|
"""Test websocket list command for an imported credential."""
|
||||||
|
client = await ws_client()
|
||||||
|
|
||||||
|
# Imported creds returned from websocket
|
||||||
|
assert await client.cmd_result("list") == [
|
||||||
|
{
|
||||||
|
CONF_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_CLIENT_ID: CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
|
"id": ID,
|
||||||
|
CONF_AUTH_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_NAME: NAME,
|
||||||
}
|
}
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -487,6 +520,7 @@ async def test_config_flow_multiple_entries(
|
|||||||
)
|
)
|
||||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
oauth_fixture.client_id = CLIENT_ID + "2"
|
oauth_fixture.client_id = CLIENT_ID + "2"
|
||||||
|
oauth_fixture.title = CLIENT_ID + "2"
|
||||||
result = await oauth_fixture.complete_external_step(result)
|
result = await oauth_fixture.complete_external_step(result)
|
||||||
assert (
|
assert (
|
||||||
result["data"].get("auth_implementation") == "fake_integration_some_client_id2"
|
result["data"].get("auth_implementation") == "fake_integration_some_client_id2"
|
||||||
@ -532,6 +566,7 @@ async def test_config_flow_with_config_credential(
|
|||||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
)
|
)
|
||||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
|
oauth_fixture.title = DEFAULT_IMPORT_NAME
|
||||||
result = await oauth_fixture.complete_external_step(result)
|
result = await oauth_fixture.complete_external_step(result)
|
||||||
# Uses the imported auth domain for compatibility
|
# Uses the imported auth domain for compatibility
|
||||||
assert result["data"].get("auth_implementation") == TEST_DOMAIN
|
assert result["data"].get("auth_implementation") == TEST_DOMAIN
|
||||||
@ -653,6 +688,7 @@ async def test_platform_with_auth_implementation(
|
|||||||
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
)
|
)
|
||||||
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
|
oauth_fixture.title = DEFAULT_IMPORT_NAME
|
||||||
result = await oauth_fixture.complete_external_step(result)
|
result = await oauth_fixture.complete_external_step(result)
|
||||||
# Uses the imported auth domain for compatibility
|
# Uses the imported auth domain for compatibility
|
||||||
assert result["data"].get("auth_implementation") == TEST_DOMAIN
|
assert result["data"].get("auth_implementation") == TEST_DOMAIN
|
||||||
@ -667,3 +703,47 @@ async def test_websocket_integration_list(ws_client: ClientFixture):
|
|||||||
assert await client.cmd_result("config") == {
|
assert await client.cmd_result("config") == {
|
||||||
"domains": ["example1", "example2"]
|
"domains": ["example1", "example2"]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_name(
|
||||||
|
hass: HomeAssistant, ws_client: ClientFixture, oauth_fixture: OAuthFixture
|
||||||
|
):
|
||||||
|
"""Test a credential with a name set."""
|
||||||
|
client = await ws_client()
|
||||||
|
result = await client.cmd_result(
|
||||||
|
"create",
|
||||||
|
{
|
||||||
|
CONF_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_CLIENT_ID: CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
|
CONF_NAME: NAME,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert result == {
|
||||||
|
CONF_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_CLIENT_ID: CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
|
CONF_NAME: NAME,
|
||||||
|
"id": ID,
|
||||||
|
}
|
||||||
|
|
||||||
|
result = await client.cmd_result("list")
|
||||||
|
assert result == [
|
||||||
|
{
|
||||||
|
CONF_DOMAIN: TEST_DOMAIN,
|
||||||
|
CONF_CLIENT_ID: CLIENT_ID,
|
||||||
|
CONF_CLIENT_SECRET: CLIENT_SECRET,
|
||||||
|
CONF_NAME: NAME,
|
||||||
|
"id": ID,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = await hass.config_entries.flow.async_init(
|
||||||
|
TEST_DOMAIN, context={"source": config_entries.SOURCE_USER}
|
||||||
|
)
|
||||||
|
assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP
|
||||||
|
oauth_fixture.title = NAME
|
||||||
|
result = await oauth_fixture.complete_external_step(result)
|
||||||
|
assert (
|
||||||
|
result["data"].get("auth_implementation") == "fake_integration_some_client_id"
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user