mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Google Assistant SDK conversation agent (#85499)
* Google Assistant SDK conversation agent * refresh token * fix session * Add tests * Add option to enable conversation agent
This commit is contained in:
parent
f2df72e014
commit
e24989b446
@ -2,21 +2,24 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
|
from gassist_text import TextAssistant
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
|
from homeassistant.components import conversation
|
||||||
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||||
from homeassistant.const import CONF_NAME, Platform
|
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
|
||||||
from homeassistant.core import HomeAssistant, ServiceCall
|
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
||||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||||
from homeassistant.helpers import discovery
|
from homeassistant.helpers import discovery, intent
|
||||||
from homeassistant.helpers.config_entry_oauth2_flow import (
|
from homeassistant.helpers.config_entry_oauth2_flow import (
|
||||||
OAuth2Session,
|
OAuth2Session,
|
||||||
async_get_config_entry_implementation,
|
async_get_config_entry_implementation,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
|
||||||
from .helpers import async_send_text_commands
|
from .helpers import async_send_text_commands, default_language_code
|
||||||
|
|
||||||
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
|
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
|
||||||
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
|
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
|
||||||
@ -58,6 +61,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
|||||||
|
|
||||||
await async_setup_service(hass)
|
await async_setup_service(hass)
|
||||||
|
|
||||||
|
entry.async_on_unload(entry.add_update_listener(update_listener))
|
||||||
|
await update_listener(hass, entry)
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
@ -90,3 +96,64 @@ async def async_setup_service(hass: HomeAssistant) -> None:
|
|||||||
send_text_command,
|
send_text_command,
|
||||||
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
|
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def update_listener(hass, entry):
|
||||||
|
"""Handle options update."""
|
||||||
|
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
|
||||||
|
agent = GoogleAssistantConversationAgent(hass, entry)
|
||||||
|
conversation.async_set_agent(hass, agent)
|
||||||
|
else:
|
||||||
|
conversation.async_set_agent(hass, None)
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||||
|
"""Google Assistant SDK conversation agent."""
|
||||||
|
|
||||||
|
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
|
||||||
|
"""Initialize the agent."""
|
||||||
|
self.hass = hass
|
||||||
|
self.entry = entry
|
||||||
|
self.assistant: TextAssistant | None = None
|
||||||
|
self.session: OAuth2Session | None = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def attribution(self):
|
||||||
|
"""Return the attribution."""
|
||||||
|
return {
|
||||||
|
"name": "Powered by Google Assistant SDK",
|
||||||
|
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
|
||||||
|
}
|
||||||
|
|
||||||
|
async def async_process(
|
||||||
|
self,
|
||||||
|
text: str,
|
||||||
|
context: Context,
|
||||||
|
conversation_id: str | None = None,
|
||||||
|
language: str | None = None,
|
||||||
|
) -> conversation.ConversationResult | None:
|
||||||
|
"""Process a sentence."""
|
||||||
|
if self.session:
|
||||||
|
session = self.session
|
||||||
|
else:
|
||||||
|
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
|
||||||
|
self.session = session
|
||||||
|
if not session.valid_token:
|
||||||
|
await session.async_ensure_token_valid()
|
||||||
|
self.assistant = None
|
||||||
|
if not self.assistant:
|
||||||
|
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
|
||||||
|
language_code = self.entry.options.get(
|
||||||
|
CONF_LANGUAGE_CODE, default_language_code(self.hass)
|
||||||
|
)
|
||||||
|
self.assistant = TextAssistant(credentials, language_code)
|
||||||
|
|
||||||
|
resp = self.assistant.assist(text)
|
||||||
|
text_response = resp[0]
|
||||||
|
|
||||||
|
language = language or self.hass.config.language
|
||||||
|
intent_response = intent.IntentResponse(language=language)
|
||||||
|
intent_response.async_set_speech(text_response)
|
||||||
|
return conversation.ConversationResult(
|
||||||
|
response=intent_response, conversation_id=conversation_id
|
||||||
|
)
|
||||||
|
@ -13,7 +13,13 @@ from homeassistant.core import callback
|
|||||||
from homeassistant.data_entry_flow import FlowResult
|
from homeassistant.data_entry_flow import FlowResult
|
||||||
from homeassistant.helpers import config_entry_oauth2_flow
|
from homeassistant.helpers import config_entry_oauth2_flow
|
||||||
|
|
||||||
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
from .const import (
|
||||||
|
CONF_ENABLE_CONVERSATION_AGENT,
|
||||||
|
CONF_LANGUAGE_CODE,
|
||||||
|
DEFAULT_NAME,
|
||||||
|
DOMAIN,
|
||||||
|
SUPPORTED_LANGUAGE_CODES,
|
||||||
|
)
|
||||||
from .helpers import default_language_code
|
from .helpers import default_language_code
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -108,6 +114,12 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
|
|||||||
CONF_LANGUAGE_CODE,
|
CONF_LANGUAGE_CODE,
|
||||||
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
|
default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
|
||||||
): vol.In(SUPPORTED_LANGUAGE_CODES),
|
): vol.In(SUPPORTED_LANGUAGE_CODES),
|
||||||
|
vol.Required(
|
||||||
|
CONF_ENABLE_CONVERSATION_AGENT,
|
||||||
|
default=self.config_entry.options.get(
|
||||||
|
CONF_ENABLE_CONVERSATION_AGENT
|
||||||
|
),
|
||||||
|
): bool,
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
@ -24,3 +24,5 @@ SUPPORTED_LANGUAGE_CODES: Final = [
|
|||||||
"ko-KR",
|
"ko-KR",
|
||||||
"pt-BR",
|
"pt-BR",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
|
||||||
|
@ -31,8 +31,10 @@
|
|||||||
"step": {
|
"step": {
|
||||||
"init": {
|
"init": {
|
||||||
"data": {
|
"data": {
|
||||||
|
"enable_conversation_agent": "Enable the conversation agent",
|
||||||
"language_code": "Language code"
|
"language_code": "Language code"
|
||||||
}
|
},
|
||||||
|
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
@ -34,8 +34,10 @@
|
|||||||
"step": {
|
"step": {
|
||||||
"init": {
|
"init": {
|
||||||
"data": {
|
"data": {
|
||||||
|
"enable_conversation_agent": "Enable conversation agent",
|
||||||
"language_code": "Language code"
|
"language_code": "Language code"
|
||||||
}
|
},
|
||||||
|
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -87,6 +87,10 @@ async def mock_setup_integration(
|
|||||||
class ExpectedCredentials:
|
class ExpectedCredentials:
|
||||||
"""Assert credentials have the expected access token."""
|
"""Assert credentials have the expected access token."""
|
||||||
|
|
||||||
|
def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None:
|
||||||
|
"""Initialize ExpectedCredentials."""
|
||||||
|
self.expected_access_token = expected_access_token
|
||||||
|
|
||||||
def __eq__(self, other: Credentials):
|
def __eq__(self, other: Credentials):
|
||||||
"""Return true if credentials have the expected access token."""
|
"""Return true if credentials have the expected access token."""
|
||||||
return other.token == ACCESS_TOKEN
|
return other.token == self.expected_access_token
|
||||||
|
@ -221,39 +221,65 @@ async def test_options_flow(
|
|||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"language_code"}
|
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"language_code": "es-ES"},
|
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {"language_code": "es-ES"}
|
assert config_entry.options == {
|
||||||
|
"enable_conversation_agent": False,
|
||||||
|
"language_code": "es-ES",
|
||||||
|
}
|
||||||
|
|
||||||
# Retrigger options flow, not change language
|
# Retrigger options flow, not change language
|
||||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"language_code"}
|
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"language_code": "es-ES"},
|
user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {"language_code": "es-ES"}
|
assert config_entry.options == {
|
||||||
|
"enable_conversation_agent": False,
|
||||||
|
"language_code": "es-ES",
|
||||||
|
}
|
||||||
|
|
||||||
# Retrigger options flow, change language
|
# Retrigger options flow, change language
|
||||||
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
assert result["type"] == "form"
|
assert result["type"] == "form"
|
||||||
assert result["step_id"] == "init"
|
assert result["step_id"] == "init"
|
||||||
data_schema = result["data_schema"].schema
|
data_schema = result["data_schema"].schema
|
||||||
assert set(data_schema) == {"language_code"}
|
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
||||||
|
|
||||||
result = await hass.config_entries.options.async_configure(
|
result = await hass.config_entries.options.async_configure(
|
||||||
result["flow_id"],
|
result["flow_id"],
|
||||||
user_input={"language_code": "en-US"},
|
user_input={"enable_conversation_agent": False, "language_code": "en-US"},
|
||||||
)
|
)
|
||||||
assert result["type"] == "create_entry"
|
assert result["type"] == "create_entry"
|
||||||
assert config_entry.options == {"language_code": "en-US"}
|
assert config_entry.options == {
|
||||||
|
"enable_conversation_agent": False,
|
||||||
|
"language_code": "en-US",
|
||||||
|
}
|
||||||
|
|
||||||
|
# Retrigger options flow, enable conversation agent
|
||||||
|
result = await hass.config_entries.options.async_init(config_entry.entry_id)
|
||||||
|
assert result["type"] == "form"
|
||||||
|
assert result["step_id"] == "init"
|
||||||
|
data_schema = result["data_schema"].schema
|
||||||
|
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
|
||||||
|
|
||||||
|
result = await hass.config_entries.options.async_configure(
|
||||||
|
result["flow_id"],
|
||||||
|
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
|
||||||
|
)
|
||||||
|
assert result["type"] == "create_entry"
|
||||||
|
assert config_entry.options == {
|
||||||
|
"enable_conversation_agent": True,
|
||||||
|
"language_code": "en-US",
|
||||||
|
}
|
||||||
|
@ -9,6 +9,7 @@ import pytest
|
|||||||
from homeassistant.components.google_assistant_sdk import DOMAIN
|
from homeassistant.components.google_assistant_sdk import DOMAIN
|
||||||
from homeassistant.config_entries import ConfigEntryState
|
from homeassistant.config_entries import ConfigEntryState
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from .conftest import ComponentSetup, ExpectedCredentials
|
from .conftest import ComponentSetup, ExpectedCredentials
|
||||||
|
|
||||||
@ -177,3 +178,107 @@ async def test_send_text_command_expired_token_refresh_failure(
|
|||||||
)
|
)
|
||||||
|
|
||||||
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
|
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
setup_integration: ComponentSetup,
|
||||||
|
) -> None:
|
||||||
|
"""Test GoogleAssistantConversationAgent."""
|
||||||
|
await setup_integration()
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
entry = entries[0]
|
||||||
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
entry, options={"enable_conversation_agent": True}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
text1 = "tell me a joke"
|
||||||
|
text2 = "tell me another one"
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
||||||
|
) as mock_text_assistant:
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{"text": text1},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{"text": text2},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert constructor is called only once since it's reused across requests
|
||||||
|
assert mock_text_assistant.call_count == 1
|
||||||
|
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US")
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text1)])
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text2)])
|
||||||
|
|
||||||
|
|
||||||
|
async def test_conversation_agent_refresh_token(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
setup_integration: ComponentSetup,
|
||||||
|
aioclient_mock: AiohttpClientMocker,
|
||||||
|
) -> None:
|
||||||
|
"""Test GoogleAssistantConversationAgent when token is expired."""
|
||||||
|
await setup_integration()
|
||||||
|
|
||||||
|
assert await async_setup_component(hass, "conversation", {})
|
||||||
|
|
||||||
|
entries = hass.config_entries.async_entries(DOMAIN)
|
||||||
|
assert len(entries) == 1
|
||||||
|
entry = entries[0]
|
||||||
|
assert entry.state is ConfigEntryState.LOADED
|
||||||
|
hass.config_entries.async_update_entry(
|
||||||
|
entry, options={"enable_conversation_agent": True}
|
||||||
|
)
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
|
||||||
|
text1 = "tell me a joke"
|
||||||
|
text2 = "tell me another one"
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.google_assistant_sdk.TextAssistant"
|
||||||
|
) as mock_text_assistant:
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{"text": text1},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Expire the token between requests
|
||||||
|
entry.data["token"]["expires_at"] = time.time() - 3600
|
||||||
|
updated_access_token = "updated-access-token"
|
||||||
|
aioclient_mock.post(
|
||||||
|
"https://oauth2.googleapis.com/token",
|
||||||
|
json={
|
||||||
|
"access_token": updated_access_token,
|
||||||
|
"refresh_token": "updated-refresh-token",
|
||||||
|
"expires_at": time.time() + 3600,
|
||||||
|
"expires_in": 3600,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
await hass.services.async_call(
|
||||||
|
"conversation",
|
||||||
|
"process",
|
||||||
|
{"text": text2},
|
||||||
|
blocking=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Assert constructor is called twice since the token was expired
|
||||||
|
assert mock_text_assistant.call_count == 2
|
||||||
|
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
|
||||||
|
mock_text_assistant.assert_has_calls(
|
||||||
|
[call(ExpectedCredentials(updated_access_token), "en-US")]
|
||||||
|
)
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text1)])
|
||||||
|
mock_text_assistant.assert_has_calls([call().assist(text2)])
|
||||||
|
Loading…
x
Reference in New Issue
Block a user