mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 10:17:09 +00:00
Google Assistant SDK: Always enable conversation agent and support multiple languages (#93201)
* Enable agent and support multiple languages * fix test
This commit is contained in:
parent
1dcaec4ece
commit
17ceacd083
@ -18,18 +18,11 @@ from homeassistant.helpers.config_entry_oauth2_flow import (
|
|||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import (
|
from .const import DATA_MEM_STORAGE, DATA_SESSION, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||||
CONF_ENABLE_CONVERSATION_AGENT,
|
|
||||||
CONF_LANGUAGE_CODE,
|
|
||||||
DATA_MEM_STORAGE,
|
|
||||||
DATA_SESSION,
|
|
||||||
DOMAIN,
|
|
||||||
)
|
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
GoogleAssistantSDKAudioView,
|
GoogleAssistantSDKAudioView,
|
||||||
InMemoryStorage,
|
InMemoryStorage,
|
||||||
async_send_text_commands,
|
async_send_text_commands,
|
||||||
default_language_code,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
|
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
|
||||||
@ -82,8 +75,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
await async_setup_service(hass)
|
await async_setup_service(hass)
|
||||||
|
|
||||||
entry.async_on_unload(entry.add_update_listener(update_listener))
|
agent = GoogleAssistantConversationAgent(hass, entry)
|
||||||
await update_listener(hass, entry)
|
conversation.async_set_agent(hass, entry, agent)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -100,8 +93,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
for service_name in hass.services.async_services()[DOMAIN]:
|
for service_name in hass.services.async_services()[DOMAIN]:
|
||||||
hass.services.async_remove(DOMAIN, service_name)
|
hass.services.async_remove(DOMAIN, service_name)
|
||||||
|
|
||||||
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
|
conversation.async_unset_agent(hass, entry)
|
||||||
conversation.async_unset_agent(hass, entry)
|
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
@ -125,15 +117,6 @@ async def async_setup_service(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def update_listener(hass, entry):
|
|
||||||
"""Handle options update."""
|
|
||||||
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
|
|
||||||
agent = GoogleAssistantConversationAgent(hass, entry)
|
|
||||||
conversation.async_set_agent(hass, entry, agent)
|
|
||||||
else:
|
|
||||||
conversation.async_unset_agent(hass, entry)
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||||
"""Google Assistant SDK conversation agent."""
|
"""Google Assistant SDK conversation agent."""
|
||||||
|
|
||||||
@ -143,6 +126,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||||||
self.entry = entry
|
self.entry = entry
|
||||||
self.assistant: TextAssistant | None = None
|
self.assistant: TextAssistant | None = None
|
||||||
self.session: OAuth2Session | None = None
|
self.session: OAuth2Session | None = None
|
||||||
|
self.language: str | None = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def attribution(self):
|
def attribution(self):
|
||||||
@ -155,10 +139,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||||||
@property
|
@property
|
||||||
def supported_languages(self) -> list[str]:
|
def supported_languages(self) -> list[str]:
|
||||||
"""Return a list of supported languages."""
|
"""Return a list of supported languages."""
|
||||||
language_code = self.entry.options.get(
|
return SUPPORTED_LANGUAGE_CODES
|
||||||
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
|
||||||
)
|
|
||||||
return [language_code]
|
|
||||||
|
|
||||||
async def async_process(
|
async def async_process(
|
||||||
self, user_input: conversation.ConversationInput
|
self, user_input: conversation.ConversationInput
|
||||||
@ -172,12 +153,10 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
|||||||
if not session.valid_token:
|
if not session.valid_token:
|
||||||
await session.async_ensure_token_valid()
|
await session.async_ensure_token_valid()
|
||||||
self.assistant = None
|
self.assistant = None
|
||||||
if not self.assistant:
|
if not self.assistant or user_input.language != self.language:
|
||||||
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
|
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
|
||||||
language_code = self.entry.options.get(
|
self.language = user_input.language
|
||||||
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
self.assistant = TextAssistant(credentials, self.language)
|
||||||
)
|
|
||||||
self.assistant = TextAssistant(credentials, language_code)
|
|
||||||
|
|
||||||
resp = self.assistant.assist(user_input.text)
|
resp = self.assistant.assist(user_input.text)
|
||||||
text_response = resp[0] or "<empty response>"
|
text_response = resp[0] or "<empty response>"
|
||||||
|
@ -13,13 +13,7 @@ from homeassistant.core import callback
|
|||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
|
||||||
from .const import (
|
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||||
CONF_ENABLE_CONVERSATION_AGENT,
|
|
||||||
CONF_LANGUAGE_CODE,
|
|
||||||
DEFAULT_NAME,
|
|
||||||
DOMAIN,
|
|
||||||
SUPPORTED_LANGUAGE_CODES,
|
|
||||||
)
|
|
||||||
from .helpers import default_language_code
|
from .helpers import default_language_code
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -114,12 +108,6 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
CONF_LANGUAGE_CODE,
|
CONF_LANGUAGE_CODE,
|
||||||
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
|
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
|
||||||
): vol.In(SUPPORTED_LANGUAGE_CODES),
|
): vol.In(SUPPORTED_LANGUAGE_CODES),
|
||||||
vol.Required(
|
|
||||||
CONF_ENABLE_CONVERSATION_AGENT,
|
|
||||||
default=self.config_entry.options.get(
|
|
||||||
CONF_ENABLE_CONVERSATION_AGENT
|
|
||||||
),
|
|
||||||
): bool,
|
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -5,7 +5,6 @@ DOMAIN: Final = "google_assistant_sdk"
|
|||||||
|
|
||||||
DEFAULT_NAME: Final = "Google Assistant SDK"
|
DEFAULT_NAME: Final = "Google Assistant SDK"
|
||||||
|
|
||||||
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
|
|
||||||
CONF_LANGUAGE_CODE: Final = "language_code"
|
CONF_LANGUAGE_CODE: Final = "language_code"
|
||||||
|
|
||||||
DATA_MEM_STORAGE: Final = "mem_storage"
|
DATA_MEM_STORAGE: Final = "mem_storage"
|
||||||
|
@ -31,10 +31,8 @@
|
|||||||
"step": {
|
"step": {
|
||||||
"init": {
|
"init": {
|
||||||
"data": {
|
"data": {
|
||||||
"enable_conversation_agent": "Enable the conversation agent",
|
|
||||||
"language_code": "Language code"
|
"language_code": "Language code"
|
||||||
},
|
}
|
||||||
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -223,65 +223,39 @@ async def test_options_flow(
|
|||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
assert set(data_schema) == {"language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
|
user_input={"language_code": "es-ES"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {
|
assert config_entry.options == {"language_code": "es-ES"}
|
||||||
"enable_conversation_agent": False,
|
|
||||||
"language_code": "es-ES",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Retrigger options flow, not change language
|
# Retrigger options flow, not change language
|
||||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
assert set(data_schema) == {"language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
|
user_input={"language_code": "es-ES"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {
|
assert config_entry.options == {"language_code": "es-ES"}
|
||||||
"enable_conversation_agent": False,
|
|
||||||
"language_code": "es-ES",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Retrigger options flow, change language
|
# Retrigger options flow, change language
|
||||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
assert set(data_schema) == {"language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
|
user_input={"language_code": "en-US"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {
|
assert config_entry.options == {"language_code": "en-US"}
|
||||||
"enable_conversation_agent": False,
|
|
||||||
"language_code": "en-US",
|
|
||||||
}
|
|
||||||
|
|
||||||
# Retrigger options flow, enable conversation agent
|
|
||||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
|
||||||
assert result["type"] == "form"
|
|
||||||
assert result["step_id"] == "init"
|
|
||||||
data_schema = result["data_schema"].schema
|
|
||||||
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
|
||||||
result["flow_id"],
|
|
||||||
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
|
|
||||||
)
|
|
||||||
assert result["type"] == "create_entry"
|
|
||||||
assert config_entry.options == {
|
|
||||||
"enable_conversation_agent": True,
|
|
||||||
"language_code": "en-US",
|
|
||||||
}
|
|
||||||
|
@ -9,8 +9,9 @@ import pytest
|
|||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import conversation
|
||||||
from homeassistant.components.google_assistant_sdk import DOMAIN
|
from homeassistant.components.google_assistant_sdk import DOMAIN
|
||||||
|
from homeassistant.components.google_assistant_sdk.const import SUPPORTED_LANGUAGE_CODES
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
from homeassistant.util.dt import utcnow
|
from homeassistant.util.dt import utcnow
|
||||||
|
|
||||||
@ -29,13 +30,9 @@ async def fetch_api_url(hass_client, url):
|
|||||||
return response.status, contents
|
return response.status, contents
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"enable_conversation_agent", [False, True], ids=["", "enable_conversation_agent"]
|
|
||||||
)
|
|
||||||
async def test_setup_success(
|
async def test_setup_success(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
setup_integration: ComponentSetup,
|
setup_integration: ComponentSetup,
|
||||||
enable_conversation_agent: bool,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test successful setup and unload."""
|
"""Test successful setup and unload."""
|
||||||
await setup_integration()
|
await setup_integration()
|
||||||
@ -44,12 +41,6 @@ async def test_setup_success(
|
|||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
assert entries[0].state is ConfigEntryState.LOADED
|
assert entries[0].state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
if enable_conversation_agent:
|
|
||||||
hass.config_entries.async_update_entry(
|
|
||||||
entries[0], options={"enable_conversation_agent": True}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
await hass.config_entries.async_unload(entries[0].entry_id)
|
await hass.config_entries.async_unload(entries[0].entry_id)
|
||||||
await hass.async_block_till_done()
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
@ -333,30 +324,21 @@ async def test_conversation_agent(
|
|||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
hass.config_entries.async_update_entry(
|
|
||||||
entry, options={"enable_conversation_agent": True}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
|
agent = await conversation._get_agent_manager(hass).async_get_agent(entry.entry_id)
|
||||||
assert agent.supported_languages == ["en-US"]
|
assert agent.attribution.keys() == {"name", "url"}
|
||||||
|
assert agent.supported_languages == SUPPORTED_LANGUAGE_CODES
|
||||||
|
|
||||||
text1 = "tell me a joke"
|
text1 = "tell me a joke"
|
||||||
text2 = "tell me another one"
|
text2 = "tell me another one"
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
||||||
) as mock_text_assistant:
|
) as mock_text_assistant:
|
||||||
await hass.services.async_call(
|
await conversation.async_converse(
|
||||||
"conversation",
|
hass, text1, None, Context(), "en-US", config_entry.entry_id
|
||||||
"process",
|
|
||||||
{"text": text1, "agent_id": config_entry.entry_id},
|
|
||||||
blocking=True,
|
|
||||||
)
|
)
|
||||||
await hass.services.async_call(
|
await conversation.async_converse(
|
||||||
"conversation",
|
hass, text2, None, Context(), "en-US", config_entry.entry_id
|
||||||
"process",
|
|
||||||
{"text": text2, "agent_id": config_entry.entry_id},
|
|
||||||
blocking=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert constructor is called only once since it's reused across requests
|
# Assert constructor is called only once since it's reused across requests
|
||||||
@ -381,21 +363,14 @@ async def test_conversation_agent_refresh_token(
|
|||||||
assert len(entries) == 1
|
assert len(entries) == 1
|
||||||
entry = entries[0]
|
entry = entries[0]
|
||||||
assert entry.state is ConfigEntryState.LOADED
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
hass.config_entries.async_update_entry(
|
|
||||||
entry, options={"enable_conversation_agent": True}
|
|
||||||
)
|
|
||||||
await hass.async_block_till_done()
|
|
||||||
|
|
||||||
text1 = "tell me a joke"
|
text1 = "tell me a joke"
|
||||||
text2 = "tell me another one"
|
text2 = "tell me another one"
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
||||||
) as mock_text_assistant:
|
) as mock_text_assistant:
|
||||||
await hass.services.async_call(
|
await conversation.async_converse(
|
||||||
"conversation",
|
hass, text1, None, Context(), "en-US", config_entry.entry_id
|
||||||
"process",
|
|
||||||
{"text": text1, "agent_id": config_entry.entry_id},
|
|
||||||
blocking=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Expire the token between requests
|
# Expire the token between requests
|
||||||
@ -411,11 +386,8 @@ async def test_conversation_agent_refresh_token(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
await hass.services.async_call(
|
await conversation.async_converse(
|
||||||
"conversation",
|
hass, text2, None, Context(), "en-US", config_entry.entry_id
|
||||||
"process",
|
|
||||||
{"text": text2, "agent_id": config_entry.entry_id},
|
|
||||||
blocking=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Assert constructor is called twice since the token was expired
|
# Assert constructor is called twice since the token was expired
|
||||||
@ -426,3 +398,38 @@ async def test_conversation_agent_refresh_token(
|
|||||||
)
|
)
|
||||||
mock_text_assistant.assert_has_calls([call().assist(text1)])
|
mock_text_assistant.assert_has_calls([call().assist(text1)])
|
||||||
mock_text_assistant.assert_has_calls([call().assist(text2)])
|
mock_text_assistant.assert_has_calls([call().assist(text2)])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent_language_changed(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
config_entry: MockConfigEntry,
|
||||||
|
setup_integration: ComponentSetup,
|
||||||
|
) -> None:
|
||||||
|
"""Test GoogleAssistantConversationAgent when language is changed."""
|
||||||
|
await setup_integration()
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
entry = entries[0]
|
||||||
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
|
||||||
|
text1 = "tell me a joke"
|
||||||
|
text2 = "cuéntame un chiste"
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
||||||
|
) as mock_text_assistant:
|
||||||
|
await conversation.async_converse(
|
||||||
|
hass, text1, None, Context(), "en-US", config_entry.entry_id
|
||||||
|
)
|
||||||
|
await conversation.async_converse(
|
||||||
|
hass, text2, None, Context(), "es-ES", config_entry.entry_id
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert constructor is called twice since the language was changed
|
||||||
|
assert mock_text_assistant.call_count == 2
|
||||||
|
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
|
||||||
|
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "es-ES")])
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text1)])
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text2)])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user