diff --git a/homeassistant/components/application_credentials/__init__.py b/homeassistant/components/application_credentials/__init__.py index 9117a91c33d..1a128c5c378 100644 --- a/homeassistant/components/application_credentials/__init__.py +++ b/homeassistant/components/application_credentials/__init__.py @@ -15,7 +15,13 @@ import voluptuous as vol from homeassistant.components import websocket_api from homeassistant.components.websocket_api.connection import ActiveConnection -from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN, CONF_ID +from homeassistant.const import ( + CONF_CLIENT_ID, + CONF_CLIENT_SECRET, + CONF_DOMAIN, + CONF_ID, + CONF_NAME, +) from homeassistant.core import HomeAssistant, callback from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import collection, config_entry_oauth2_flow @@ -39,12 +45,14 @@ STORAGE_KEY = DOMAIN STORAGE_VERSION = 1 DATA_STORAGE = "storage" CONF_AUTH_DOMAIN = "auth_domain" +DEFAULT_IMPORT_NAME = "Import from configuration.yaml" CREATE_FIELDS = { vol.Required(CONF_DOMAIN): cv.string, vol.Required(CONF_CLIENT_ID): cv.string, vol.Required(CONF_CLIENT_SECRET): cv.string, vol.Optional(CONF_AUTH_DOMAIN): cv.string, + vol.Optional(CONF_NAME): cv.string, } UPDATE_FIELDS: dict = {} # Not supported @@ -55,6 +63,7 @@ class ClientCredential: client_id: str client_secret: str + name: str | None = None @dataclass @@ -122,7 +131,9 @@ class ApplicationCredentialsStorageCollection(collection.StorageCollection): item[CONF_AUTH_DOMAIN] if CONF_AUTH_DOMAIN in item else item[CONF_ID] ) credentials[auth_domain] = ClientCredential( - item[CONF_CLIENT_ID], item[CONF_CLIENT_SECRET] + client_id=item[CONF_CLIENT_ID], + client_secret=item[CONF_CLIENT_SECRET], + name=item.get(CONF_NAME), ) return credentials @@ -169,6 +180,7 @@ async def async_import_client_credential( CONF_CLIENT_SECRET: credential.client_secret, CONF_AUTH_DOMAIN: auth_domain if auth_domain else domain, } + item[CONF_NAME] = credential.name if credential.name else DEFAULT_IMPORT_NAME await storage_collection.async_import_item(item) @@ -191,11 +203,12 @@ class AuthImplementation(config_entry_oauth2_flow.LocalOAuth2Implementation): authorization_server.authorize_url, authorization_server.token_url, ) + self._name = credential.name @property def name(self) -> str: """Name of the implementation.""" - return self.client_id + return self._name or self.client_id async def _async_provide_implementation( diff --git a/tests/components/application_credentials/test_init.py b/tests/components/application_credentials/test_init.py index b62f8a0139c..b89a60f42e4 100644 --- a/tests/components/application_credentials/test_init.py +++ b/tests/components/application_credentials/test_init.py @@ -13,13 +13,19 @@ import pytest from homeassistant import config_entries, data_entry_flow from homeassistant.components.application_credentials import ( CONF_AUTH_DOMAIN, + DEFAULT_IMPORT_NAME, DOMAIN, AuthImplementation, AuthorizationServer, ClientCredential, async_import_client_credential, ) -from homeassistant.const import CONF_CLIENT_ID, CONF_CLIENT_SECRET, CONF_DOMAIN +from homeassistant.const import ( + CONF_CLIENT_ID, + CONF_CLIENT_SECRET, + CONF_DOMAIN, + CONF_NAME, +) from homeassistant.core import HomeAssistant from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.setup import async_setup_component @@ -29,11 +35,13 @@ from tests.common import mock_platform CLIENT_ID = "some-client-id" CLIENT_SECRET = "some-client-secret" DEVELOPER_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET) +NAMED_CREDENTIAL = ClientCredential(CLIENT_ID, CLIENT_SECRET, "Name") ID = "fake_integration_some_client_id" AUTHORIZE_URL = "https://example.com/auth" TOKEN_URL = "https://example.com/oauth2/v4/token" REFRESH_TOKEN = "mock-refresh-token" ACCESS_TOKEN = "mock-access-token" +NAME = "Name" TEST_DOMAIN = "fake_integration" @@ -118,6 +126,7 @@ class OAuthFixture: self.hass_client = hass_client self.aioclient_mock = aioclient_mock self.client_id = CLIENT_ID + self.title = CLIENT_ID async def complete_external_step( self, result: data_entry_flow.FlowResult @@ -152,7 +161,7 @@ class OAuthFixture: result = await self.hass.config_entries.flow.async_configure(result["flow_id"]) assert result.get("type") == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result.get("title") == self.client_id + assert result.get("title") == self.title assert "data" in result assert "token" in result["data"] return result @@ -348,6 +357,7 @@ async def test_websocket_import_config( CONF_CLIENT_SECRET: CLIENT_SECRET, "id": ID, CONF_AUTH_DOMAIN: TEST_DOMAIN, + CONF_NAME: DEFAULT_IMPORT_NAME, } ] @@ -375,6 +385,29 @@ async def test_import_duplicate_credentials( CONF_CLIENT_SECRET: CLIENT_SECRET, "id": ID, CONF_AUTH_DOMAIN: TEST_DOMAIN, + CONF_NAME: DEFAULT_IMPORT_NAME, + } + ] + + +@pytest.mark.parametrize("config_credential", [NAMED_CREDENTIAL]) +async def test_import_named_credential( + ws_client: ClientFixture, + config_credential: ClientCredential, + import_config_credential: Any, +): + """Test websocket list command for an imported credential.""" + client = await ws_client() + + # Imported creds returned from websocket + assert await client.cmd_result("list") == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + "id": ID, + CONF_AUTH_DOMAIN: TEST_DOMAIN, + CONF_NAME: NAME, } ] @@ -487,6 +520,7 @@ async def test_config_flow_multiple_entries( ) assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP oauth_fixture.client_id = CLIENT_ID + "2" + oauth_fixture.title = CLIENT_ID + "2" result = await oauth_fixture.complete_external_step(result) assert ( result["data"].get("auth_implementation") == "fake_integration_some_client_id2" @@ -532,6 +566,7 @@ async def test_config_flow_with_config_credential( TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} ) assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + oauth_fixture.title = DEFAULT_IMPORT_NAME result = await oauth_fixture.complete_external_step(result) # Uses the imported auth domain for compatibility assert result["data"].get("auth_implementation") == TEST_DOMAIN @@ -653,6 +688,7 @@ async def test_platform_with_auth_implementation( TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} ) assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + oauth_fixture.title = DEFAULT_IMPORT_NAME result = await oauth_fixture.complete_external_step(result) # Uses the imported auth domain for compatibility assert result["data"].get("auth_implementation") == TEST_DOMAIN @@ -667,3 +703,47 @@ async def test_websocket_integration_list(ws_client: ClientFixture): assert await client.cmd_result("config") == { "domains": ["example1", "example2"] } + + +async def test_name( + hass: HomeAssistant, ws_client: ClientFixture, oauth_fixture: OAuthFixture +): + """Test a credential with a name set.""" + client = await ws_client() + result = await client.cmd_result( + "create", + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + CONF_NAME: NAME, + }, + ) + assert result == { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + CONF_NAME: NAME, + "id": ID, + } + + result = await client.cmd_result("list") + assert result == [ + { + CONF_DOMAIN: TEST_DOMAIN, + CONF_CLIENT_ID: CLIENT_ID, + CONF_CLIENT_SECRET: CLIENT_SECRET, + CONF_NAME: NAME, + "id": ID, + } + ] + + result = await hass.config_entries.flow.async_init( + TEST_DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result.get("type") == data_entry_flow.RESULT_TYPE_EXTERNAL_STEP + oauth_fixture.title = NAME + result = await oauth_fixture.complete_external_step(result) + assert ( + result["data"].get("auth_implementation") == "fake_integration_some_client_id" + )