Google Assistant SDK conversation agent (#85499)

* Google Assistant SDK conversation agent

* refresh token

* fix session

* Add tests

* Add option to enable conversation agent
This commit is contained in:
tronikos 2023-01-09 17:53:41 -08:00 committed by GitHub
parent f2df72e014
commit e24989b446
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 238 additions and 18 deletions

View File

@ -2,21 +2,24 @@
from __future__ import annotations from __future__ import annotations
import aiohttp import aiohttp
from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials
import voluptuous as vol import voluptuous as vol
from homeassistant.components import conversation
from homeassistant.config_entries import ConfigEntry, ConfigEntryState from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_NAME, Platform from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery from homeassistant.helpers import discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import ( from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session, OAuth2Session,
async_get_config_entry_implementation, async_get_config_entry_implementation,
) )
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands from .helpers import async_send_text_commands, default_language_code
SERVICE_SEND_TEXT_COMMAND = "send_text_command" SERVICE_SEND_TEXT_COMMAND = "send_text_command"
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command" SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
@ -58,6 +61,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
await async_setup_service(hass) await async_setup_service(hass)
entry.async_on_unload(entry.add_update_listener(update_listener))
await update_listener(hass, entry)
return True return True
@ -90,3 +96,64 @@ async def async_setup_service(hass: HomeAssistant) -> None:
send_text_command, send_text_command,
schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA, schema=SERVICE_SEND_TEXT_COMMAND_SCHEMA,
) )
async def update_listener(hass, entry):
"""Handle options update."""
if entry.options.get(CONF_ENABLE_CONVERSATION_AGENT, False):
agent = GoogleAssistantConversationAgent(hass, entry)
conversation.async_set_agent(hass, agent)
else:
conversation.async_set_agent(hass, None)
class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
"""Google Assistant SDK conversation agent."""
def __init__(self, hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Initialize the agent."""
self.hass = hass
self.entry = entry
self.assistant: TextAssistant | None = None
self.session: OAuth2Session | None = None
@property
def attribution(self):
"""Return the attribution."""
return {
"name": "Powered by Google Assistant SDK",
"url": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
}
async def async_process(
self,
text: str,
context: Context,
conversation_id: str | None = None,
language: str | None = None,
) -> conversation.ConversationResult | None:
"""Process a sentence."""
if self.session:
session = self.session
else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
self.session = session
if not session.valid_token:
await session.async_ensure_token_valid()
self.assistant = None
if not self.assistant:
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = self.entry.options.get(
CONF_LANGUAGE_CODE, default_language_code(self.hass)
)
self.assistant = TextAssistant(credentials, language_code)
resp = self.assistant.assist(text)
text_response = resp[0]
language = language or self.hass.config.language
intent_response = intent.IntentResponse(language=language)
intent_response.async_set_speech(text_response)
return conversation.ConversationResult(
response=intent_response, conversation_id=conversation_id
)

View File

@ -13,7 +13,13 @@ from homeassistant.core import callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers import config_entry_oauth2_flow from homeassistant.helpers import config_entry_oauth2_flow
from .const import CONF_LANGUAGE_CODE, DEFAULT_NAME, DOMAIN, SUPPORTED_LANGUAGE_CODES from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DEFAULT_NAME,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
from .helpers import default_language_code from .helpers import default_language_code
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -108,6 +114,12 @@ class OptionsFlowHandler(config_entries.OptionsFlow):
CONF_LANGUAGE_CODE, CONF_LANGUAGE_CODE,
default=self.config_entry.options.get(CONF_LANGUAGE_CODE), default=self.config_entry.options.get(CONF_LANGUAGE_CODE),
): vol.In(SUPPORTED_LANGUAGE_CODES), ): vol.In(SUPPORTED_LANGUAGE_CODES),
vol.Required(
CONF_ENABLE_CONVERSATION_AGENT,
default=self.config_entry.options.get(
CONF_ENABLE_CONVERSATION_AGENT
),
): bool,
} }
), ),
) )

View File

@ -24,3 +24,5 @@ SUPPORTED_LANGUAGE_CODES: Final = [
"ko-KR", "ko-KR",
"pt-BR", "pt-BR",
] ]
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"

View File

@ -31,8 +31,10 @@
"step": { "step": {
"init": { "init": {
"data": { "data": {
"enable_conversation_agent": "Enable the conversation agent",
"language_code": "Language code" "language_code": "Language code"
} },
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
} }
} }
}, },

View File

@ -34,8 +34,10 @@
"step": { "step": {
"init": { "init": {
"data": { "data": {
"enable_conversation_agent": "Enable conversation agent",
"language_code": "Language code" "language_code": "Language code"
} },
"description": "Set language for interactions with Google Assistant and whether you want to enable the conversation agent."
} }
} }
} }

View File

@ -87,6 +87,10 @@ async def mock_setup_integration(
class ExpectedCredentials: class ExpectedCredentials:
"""Assert credentials have the expected access token.""" """Assert credentials have the expected access token."""
def __init__(self, expected_access_token: str = ACCESS_TOKEN) -> None:
"""Initialize ExpectedCredentials."""
self.expected_access_token = expected_access_token
def __eq__(self, other: Credentials): def __eq__(self, other: Credentials):
"""Return true if credentials have the expected access token.""" """Return true if credentials have the expected access token."""
return other.token == ACCESS_TOKEN return other.token == self.expected_access_token

View File

@ -221,39 +221,65 @@ async def test_options_flow(
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "init" assert result["step_id"] == "init"
data_schema = result["data_schema"].schema data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"} assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={"language_code": "es-ES"}, user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"} assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
# Retrigger options flow, not change language # Retrigger options flow, not change language
result = await hass.config_entries.options.async_init(config_entry.entry_id) result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "init" assert result["step_id"] == "init"
data_schema = result["data_schema"].schema data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"} assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={"language_code": "es-ES"}, user_input={"enable_conversation_agent": False, "language_code": "es-ES"},
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "es-ES"} assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "es-ES",
}
# Retrigger options flow, change language # Retrigger options flow, change language
result = await hass.config_entries.options.async_init(config_entry.entry_id) result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form" assert result["type"] == "form"
assert result["step_id"] == "init" assert result["step_id"] == "init"
data_schema = result["data_schema"].schema data_schema = result["data_schema"].schema
assert set(data_schema) == {"language_code"} assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure( result = await hass.config_entries.options.async_configure(
result["flow_id"], result["flow_id"],
user_input={"language_code": "en-US"}, user_input={"enable_conversation_agent": False, "language_code": "en-US"},
) )
assert result["type"] == "create_entry" assert result["type"] == "create_entry"
assert config_entry.options == {"language_code": "en-US"} assert config_entry.options == {
"enable_conversation_agent": False,
"language_code": "en-US",
}
# Retrigger options flow, enable conversation agent
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] == "form"
assert result["step_id"] == "init"
data_schema = result["data_schema"].schema
assert set(data_schema) == {"enable_conversation_agent", "language_code"}
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"enable_conversation_agent": True, "language_code": "en-US"},
)
assert result["type"] == "create_entry"
assert config_entry.options == {
"enable_conversation_agent": True,
"language_code": "en-US",
}

View File

@ -9,6 +9,7 @@ import pytest
from homeassistant.components.google_assistant_sdk import DOMAIN from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from .conftest import ComponentSetup, ExpectedCredentials from .conftest import ComponentSetup, ExpectedCredentials
@ -177,3 +178,107 @@ async def test_send_text_command_expired_token_refresh_failure(
) )
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,
) -> None:
"""Test GoogleAssistantConversationAgent."""
await setup_integration()
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()
text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)
# Assert constructor is called only once since it's reused across requests
assert mock_text_assistant.call_count == 1
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), "en-US")
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])
async def test_conversation_agent_refresh_token(
hass: HomeAssistant,
setup_integration: ComponentSetup,
aioclient_mock: AiohttpClientMocker,
) -> None:
"""Test GoogleAssistantConversationAgent when token is expired."""
await setup_integration()
assert await async_setup_component(hass, "conversation", {})
entries = hass.config_entries.async_entries(DOMAIN)
assert len(entries) == 1
entry = entries[0]
assert entry.state is ConfigEntryState.LOADED
hass.config_entries.async_update_entry(
entry, options={"enable_conversation_agent": True}
)
await hass.async_block_till_done()
text1 = "tell me a joke"
text2 = "tell me another one"
with patch(
"homeassistant.components.google_assistant_sdk.TextAssistant"
) as mock_text_assistant:
await hass.services.async_call(
"conversation",
"process",
{"text": text1},
blocking=True,
)
# Expire the token between requests
entry.data["token"]["expires_at"] = time.time() - 3600
updated_access_token = "updated-access-token"
aioclient_mock.post(
"https://oauth2.googleapis.com/token",
json={
"access_token": updated_access_token,
"refresh_token": "updated-refresh-token",
"expires_at": time.time() + 3600,
"expires_in": 3600,
},
)
await hass.services.async_call(
"conversation",
"process",
{"text": text2},
blocking=True,
)
# Assert constructor is called twice since the token was expired
assert mock_text_assistant.call_count == 2
mock_text_assistant.assert_has_calls([call(ExpectedCredentials(), "en-US")])
mock_text_assistant.assert_has_calls(
[call(ExpectedCredentials(updated_access_token), "en-US")]
)
mock_text_assistant.assert_has_calls([call().assist(text1)])
mock_text_assistant.assert_has_calls([call().assist(text2)])