Google Assistant SDK: support audio response playback (#85989)

* Google Assistant SDK: support response playback

* Update PATHS_WITHOUT_AUTH

* gassist-text==0.0.8

* address review comments
This commit is contained in:
tronikos 2023-01-24 08:19:23 -08:00 committed by GitHub
parent 80a8da26bc
commit 0daaa37e09
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 244 additions and 24 deletions

View File

@ -11,23 +11,36 @@ from homeassistant.config_entries import ConfigEntry, ConfigEntryState
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
from homeassistant.core import Context, HomeAssistant, ServiceCall
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import discovery, intent
from homeassistant.helpers import config_validation as cv, discovery, intent
from homeassistant.helpers.config_entry_oauth2_flow import (
OAuth2Session,
async_get_config_entry_implementation,
)
from homeassistant.helpers.typing import ConfigType
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
from .helpers import async_send_text_commands, default_language_code
from .const import (
CONF_ENABLE_CONVERSATION_AGENT,
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
)
from .helpers import (
GoogleAssistantSDKAudioView,
InMemoryStorage,
async_send_text_commands,
default_language_code,
)
# Name of the service registered under this integration's domain.
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
# Keys accepted in the service call data.
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER = "media_player"
SERVICE_SEND_TEXT_COMMAND_SCHEMA = vol.All(
    {
        # The command is mandatory and must be a string of at least one character.
        vol.Required(SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND): vol.All(
            str, vol.Length(min=1)
        ),
        # Optional media player entity id(s), validated by cv.comp_entity_ids.
        vol.Optional(SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER): cv.comp_entity_ids,
    },
)
@ -45,6 +58,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Google Assistant SDK from a config entry."""
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
implementation = await async_get_config_entry_implementation(hass, entry)
session = OAuth2Session(hass, entry, implementation)
try:
@ -57,7 +72,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
raise ConfigEntryNotReady from err
except aiohttp.ClientError as err:
raise ConfigEntryNotReady from err
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session
hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session
mem_storage = InMemoryStorage(hass)
hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
await async_setup_service(hass)
@ -88,7 +107,10 @@ async def async_setup_service(hass: HomeAssistant) -> None:
async def send_text_command(call: ServiceCall) -> None:
"""Send a text command to Google Assistant SDK."""
command: str = call.data[SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND]
await async_send_text_commands([command], hass)
media_players: list[str] | None = call.data.get(
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER
)
await async_send_text_commands(hass, [command], media_players)
hass.services.async_register(
DOMAIN,
@ -136,7 +158,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
if self.session:
session = self.session
else:
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION]
self.session = session
if not session.valid_token:
await session.async_ensure_token_valid()

View File

@ -5,8 +5,12 @@ DOMAIN: Final = "google_assistant_sdk"
DEFAULT_NAME: Final = "Google Assistant SDK"
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
CONF_LANGUAGE_CODE: Final = "language_code"
DATA_MEM_STORAGE: Final = "mem_storage"
DATA_SESSION: Final = "session"
# https://developers.google.com/assistant/sdk/reference/rpc/languages
SUPPORTED_LANGUAGE_CODES: Final = [
"de-DE",
@ -24,5 +28,3 @@ SUPPORTED_LANGUAGE_CODES: Final = [
"ko-KR",
"pt-BR",
]
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"

View File

@ -1,18 +1,38 @@
"""Helper classes for Google Assistant SDK integration."""
from __future__ import annotations
from http import HTTPStatus
import logging
from typing import Any
import uuid
import aiohttp
from aiohttp import web
from gassist_text import TextAssistant
from google.oauth2.credentials import Credentials
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_player import (
ATTR_MEDIA_ANNOUNCE,
ATTR_MEDIA_CONTENT_ID,
ATTR_MEDIA_CONTENT_TYPE,
DOMAIN as DOMAIN_MP,
SERVICE_PLAY_MEDIA,
MediaType,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_ACCESS_TOKEN
from homeassistant.const import ATTR_ENTITY_ID, CONF_ACCESS_TOKEN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
from homeassistant.helpers.event import async_call_later
from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
from .const import (
CONF_LANGUAGE_CODE,
DATA_MEM_STORAGE,
DATA_SESSION,
DOMAIN,
SUPPORTED_LANGUAGE_CODES,
)
_LOGGER = logging.getLogger(__name__)
@ -28,12 +48,14 @@ DEFAULT_LANGUAGE_CODES = {
}
async def async_send_text_commands(commands: list[str], hass: HomeAssistant) -> None:
async def async_send_text_commands(
hass: HomeAssistant, commands: list[str], media_players: list[str] | None = None
) -> None:
"""Send text commands to Google Assistant Service."""
# There can only be 1 entry (config_flow has single_instance_allowed)
entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
session: OAuth2Session = hass.data[DOMAIN].get(entry.entry_id)
session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION]
try:
await session.async_ensure_token_valid()
except aiohttp.ClientResponseError as err:
@ -43,10 +65,32 @@ async def async_send_text_commands(commands: list[str], hass: HomeAssistant) ->
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
language_code = entry.options.get(CONF_LANGUAGE_CODE, default_language_code(hass))
with TextAssistant(credentials, language_code) as assistant:
with TextAssistant(
credentials, language_code, audio_out=bool(media_players)
) as assistant:
for command in commands:
text_response = assistant.assist(command)[0]
resp = assistant.assist(command)
text_response = resp[0]
_LOGGER.debug("command: %s\nresponse: %s", command, text_response)
audio_response = resp[2]
if media_players and audio_response:
mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][
DATA_MEM_STORAGE
]
audio_url = GoogleAssistantSDKAudioView.url.format(
filename=mem_storage.store_and_get_identifier(audio_response)
)
await hass.services.async_call(
DOMAIN_MP,
SERVICE_PLAY_MEDIA,
{
ATTR_ENTITY_ID: media_players,
ATTR_MEDIA_CONTENT_ID: audio_url,
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
ATTR_MEDIA_ANNOUNCE: True,
},
blocking=True,
)
def default_language_code(hass: HomeAssistant):
@ -55,3 +99,53 @@ def default_language_code(hass: HomeAssistant):
if language_code in SUPPORTED_LANGUAGE_CODES:
return language_code
return DEFAULT_LANGUAGE_CODES.get(hass.config.language, "en-US")
class InMemoryStorage:
    """Temporarily store and retrieve data from in memory storage."""

    def __init__(self, hass: HomeAssistant) -> None:
        """Initialize InMemoryStorage."""
        self.hass: HomeAssistant = hass
        # Maps an identifier to the payload stored under it.
        self.mem: dict[str, bytes] = {}

    @staticmethod
    def _generate_identifier() -> str:
        """Return a new random identifier as 32 hexadecimal characters."""
        # uuid4 is fully random. uuid1 would embed the host's node address and
        # a timestamp, making identifiers guessable and leaking host details
        # into every URL built from them.
        return uuid.uuid4().hex

    def store_and_get_identifier(self, data: bytes) -> str:
        """Temporarily store data and return an identifier to retrieve it.

        Data expires after 5 minutes.
        """
        identifier = self._generate_identifier()
        self.mem[identifier] = data

        def async_remove_from_mem(*_: Any) -> None:
            """Cleanup memory."""
            # pop() with a default keeps the cleanup a no-op if the entry is
            # already gone.
            self.mem.pop(identifier, None)

        # Remove the entry from memory 5 minutes later
        async_call_later(self.hass, 5 * 60, async_remove_from_mem)

        return identifier

    def retrieve(self, identifier: str) -> bytes | None:
        """Retrieve previously stored data, or None if nothing is stored."""
        return self.mem.get(identifier)
class GoogleAssistantSDKAudioView(HomeAssistantView):
    """HTTP view that serves stored Google Assistant SDK audio responses."""

    requires_auth = True
    url = "/api/google_assistant_sdk/audio/{filename}"
    name = "api:google_assistant_sdk:audio"

    def __init__(self, mem_storage: InMemoryStorage) -> None:
        """Keep a reference to the storage the audio is read from."""
        self.mem_storage: InMemoryStorage = mem_storage

    async def get(self, request: web.Request, filename: str) -> web.Response:
        """Return the audio stored under ``filename``, or 404 if there is none."""
        # A missing entry and an empty payload are both answered with 404.
        if audio := self.mem_storage.retrieve(filename):
            return web.Response(body=audio, content_type="audio/mpeg")
        return web.Response(status=HTTPStatus.NOT_FOUND)

View File

@ -2,9 +2,9 @@
"domain": "google_assistant_sdk",
"name": "Google Assistant SDK",
"config_flow": true,
"dependencies": ["application_credentials"],
"dependencies": ["application_credentials", "http"],
"documentation": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
"requirements": ["gassist-text==0.0.7"],
"requirements": ["gassist-text==0.0.8"],
"codeowners": ["@tronikos"],
"iot_class": "cloud_polling",
"integration_type": "service"

View File

@ -70,4 +70,4 @@ class BroadcastNotificationService(BaseNotificationService):
commands.append(
broadcast_commands(language_code)[1].format(message, target)
)
await async_send_text_commands(commands, self.hass)
await async_send_text_commands(self.hass, commands)

View File

@ -8,3 +8,10 @@ send_text_command:
example: turn off kitchen TV
selector:
text:
media_player:
name: Media Player Entity
description: Name(s) of media player entities to play response on
example: media_player.living_room_speaker
selector:
entity:
domain: media_player

View File

@ -754,7 +754,7 @@ fritzconnection==1.10.3
gTTS==2.2.4
# homeassistant.components.google_assistant_sdk
gassist-text==0.0.7
gassist-text==0.0.8
# homeassistant.components.google
gcal-sync==4.1.2

View File

@ -573,7 +573,7 @@ fritzconnection==1.10.3
gTTS==2.2.4
# homeassistant.components.google_assistant_sdk
gassist-text==0.0.7
gassist-text==0.0.8
# homeassistant.components.google
gcal-sync==4.1.2

View File

@ -1,4 +1,5 @@
"""Tests for Google Assistant SDK."""
from datetime import timedelta
import http
import time
from unittest.mock import call, patch
@ -10,12 +11,22 @@ from homeassistant.components.google_assistant_sdk import DOMAIN
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import HomeAssistant
from homeassistant.setup import async_setup_component
from homeassistant.util.dt import utcnow
from .conftest import ComponentSetup, ExpectedCredentials
from tests.common import async_fire_time_changed, async_mock_service
from tests.test_util.aiohttp import AiohttpClientMocker
async def fetch_api_url(hass_client, url):
    """Request ``url`` with a test client and return ``(status, body)``."""
    api_client = await hass_client()
    resp = await api_client.get(url)
    # Read the body before building the result, as the original helper did.
    body = await resp.read()
    return resp.status, body
async def test_setup_success(
hass: HomeAssistant, setup_integration: ComponentSetup
) -> None:
@ -129,7 +140,7 @@ async def test_send_text_command(
blocking=True,
)
mock_text_assistant.assert_called_once_with(
ExpectedCredentials(), expected_language_code
ExpectedCredentials(), expected_language_code, audio_out=False
)
mock_text_assistant.assert_has_calls([call().__enter__().assist(command)])
@ -180,6 +191,88 @@ async def test_send_text_command_expired_token_refresh_failure(
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
async def test_send_text_command_media_player(
    hass: HomeAssistant, setup_integration: ComponentSetup, hass_client
) -> None:
    """Test send_text_command with media_player."""
    await setup_integration()

    play_media_calls = async_mock_service(hass, "media_player", "play_media")

    command = "tell me a joke"
    media_player = "media_player.office_speaker"
    audio_response1 = b"joke1 audio response bytes"
    audio_response2 = b"joke2 audio response bytes"
    with patch(
        "homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
        side_effect=[
            ("joke1 text", None, audio_response1),
            ("joke2 text", None, audio_response2),
        ],
    ) as mock_assist_call:
        # Run the same command twice, getting different audio response each time.
        for _ in range(2):
            await hass.services.async_call(
                DOMAIN,
                "send_text_command",
                {
                    "command": command,
                    "media_player": media_player,
                },
                blocking=True,
            )

    mock_assist_call.assert_has_calls([call(command), call(command)])
    assert len(play_media_calls) == 2
    for play_media_call in play_media_calls:
        assert play_media_call.data["entity_id"] == [media_player]
        assert play_media_call.data["media_content_id"].startswith(
            "/api/google_assistant_sdk/audio/"
        )

    audio_url1 = play_media_calls[0].data["media_content_id"]
    audio_url2 = play_media_calls[1].data["media_content_id"]
    assert audio_url1 != audio_url2

    # Each URL paired with the audio bytes it is expected to serve.
    expected_audio = (
        (audio_url1, audio_response1),
        (audio_url2, audio_response2),
    )

    # Assert that both audio responses can be served
    for audio_url, audio_response in expected_audio:
        status, response = await fetch_api_url(hass_client, audio_url)
        assert status == http.HTTPStatus.OK
        assert response == audio_response

    # Assert a nonexistent URL returns 404
    status, _ = await fetch_api_url(
        hass_client, "/api/google_assistant_sdk/audio/nonexistent"
    )
    assert status == http.HTTPStatus.NOT_FOUND

    # Assert that both audio responses can still be served before the 5 minutes expiration
    async_fire_time_changed(hass, utcnow() + timedelta(minutes=4))
    for audio_url, audio_response in expected_audio:
        status, response = await fetch_api_url(hass_client, audio_url)
        assert status == http.HTTPStatus.OK
        assert response == audio_response

    # Assert that they cannot be served after the 5 minutes expiration
    async_fire_time_changed(hass, utcnow() + timedelta(minutes=6))
    for audio_url, _ in expected_audio:
        status, response = await fetch_api_url(hass_client, audio_url)
        assert status == http.HTTPStatus.NOT_FOUND
async def test_conversation_agent(
hass: HomeAssistant,
setup_integration: ComponentSetup,

View File

@ -44,7 +44,9 @@ async def test_broadcast_no_targets(
{notify.ATTR_MESSAGE: message},
)
await hass.async_block_till_done()
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), language_code)
mock_text_assistant.assert_called_once_with(
ExpectedCredentials(), language_code, audio_out=False
)
mock_text_assistant.assert_has_calls([call().__enter__().assist(expected_command)])
@ -84,7 +86,7 @@ async def test_broadcast_one_target(
with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None],
return_value=("text_response", None, b""),
) as mock_assist_call:
await hass.services.async_call(
notify.DOMAIN,
@ -108,7 +110,7 @@ async def test_broadcast_two_targets(
expected_command2 = "broadcast to master bedroom time for dinner"
with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None],
return_value=("text_response", None, b""),
) as mock_assist_call:
await hass.services.async_call(
notify.DOMAIN,
@ -129,7 +131,7 @@ async def test_broadcast_empty_message(
with patch(
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
return_value=["text_response", None],
return_value=("text_response", None, b""),
) as mock_assist_call:
await hass.services.async_call(
notify.DOMAIN,