mirror of
https://github.com/home-assistant/core.git
synced 2025-07-13 16:27:08 +00:00
Google Assistant SDK: support audio response playback (#85989)
* Google Assistant SDK: support response playback * Update PATHS_WITHOUT_AUTH * gassist-text==0.0.8 * address review comments
This commit is contained in:
parent
80a8da26bc
commit
0daaa37e09
@ -11,23 +11,36 @@ from homeassistant.config_entries import ConfigEntry, ConfigEntryState
|
||||
from homeassistant.const import CONF_ACCESS_TOKEN, CONF_NAME, Platform
|
||||
from homeassistant.core import Context, HomeAssistant, ServiceCall
|
||||
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
|
||||
from homeassistant.helpers import discovery, intent
|
||||
from homeassistant.helpers import config_validation as cv, discovery, intent
|
||||
from homeassistant.helpers.config_entry_oauth2_flow import (
|
||||
OAuth2Session,
|
||||
async_get_config_entry_implementation,
|
||||
)
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import CONF_ENABLE_CONVERSATION_AGENT, CONF_LANGUAGE_CODE, DOMAIN
|
||||
from .helpers import async_send_text_commands, default_language_code
|
||||
from .const import (
|
||||
CONF_ENABLE_CONVERSATION_AGENT,
|
||||
CONF_LANGUAGE_CODE,
|
||||
DATA_MEM_STORAGE,
|
||||
DATA_SESSION,
|
||||
DOMAIN,
|
||||
)
|
||||
from .helpers import (
|
||||
GoogleAssistantSDKAudioView,
|
||||
InMemoryStorage,
|
||||
async_send_text_commands,
|
||||
default_language_code,
|
||||
)
|
||||
|
||||
SERVICE_SEND_TEXT_COMMAND = "send_text_command"
|
||||
SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND = "command"
|
||||
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER = "media_player"
|
||||
SERVICE_SEND_TEXT_COMMAND_SCHEMA = vol.All(
|
||||
{
|
||||
vol.Required(SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND): vol.All(
|
||||
str, vol.Length(min=1)
|
||||
),
|
||||
vol.Optional(SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER): cv.comp_entity_ids,
|
||||
},
|
||||
)
|
||||
|
||||
@ -45,6 +58,8 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
|
||||
async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
"""Set up Google Assistant SDK from a config entry."""
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = {}
|
||||
|
||||
implementation = await async_get_config_entry_implementation(hass, entry)
|
||||
session = OAuth2Session(hass, entry, implementation)
|
||||
try:
|
||||
@ -57,7 +72,11 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
|
||||
raise ConfigEntryNotReady from err
|
||||
except aiohttp.ClientError as err:
|
||||
raise ConfigEntryNotReady from err
|
||||
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = session
|
||||
hass.data[DOMAIN][entry.entry_id][DATA_SESSION] = session
|
||||
|
||||
mem_storage = InMemoryStorage(hass)
|
||||
hass.data[DOMAIN][entry.entry_id][DATA_MEM_STORAGE] = mem_storage
|
||||
hass.http.register_view(GoogleAssistantSDKAudioView(mem_storage))
|
||||
|
||||
await async_setup_service(hass)
|
||||
|
||||
@ -88,7 +107,10 @@ async def async_setup_service(hass: HomeAssistant) -> None:
|
||||
async def send_text_command(call: ServiceCall) -> None:
|
||||
"""Send a text command to Google Assistant SDK."""
|
||||
command: str = call.data[SERVICE_SEND_TEXT_COMMAND_FIELD_COMMAND]
|
||||
await async_send_text_commands([command], hass)
|
||||
media_players: list[str] | None = call.data.get(
|
||||
SERVICE_SEND_TEXT_COMMAND_FIELD_MEDIA_PLAYER
|
||||
)
|
||||
await async_send_text_commands(hass, [command], media_players)
|
||||
|
||||
hass.services.async_register(
|
||||
DOMAIN,
|
||||
@ -136,7 +158,7 @@ class GoogleAssistantConversationAgent(conversation.AbstractConversationAgent):
|
||||
if self.session:
|
||||
session = self.session
|
||||
else:
|
||||
session = self.hass.data[DOMAIN].get(self.entry.entry_id)
|
||||
session = self.hass.data[DOMAIN][self.entry.entry_id][DATA_SESSION]
|
||||
self.session = session
|
||||
if not session.valid_token:
|
||||
await session.async_ensure_token_valid()
|
||||
|
@ -5,8 +5,12 @@ DOMAIN: Final = "google_assistant_sdk"
|
||||
|
||||
DEFAULT_NAME: Final = "Google Assistant SDK"
|
||||
|
||||
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
|
||||
CONF_LANGUAGE_CODE: Final = "language_code"
|
||||
|
||||
DATA_MEM_STORAGE: Final = "mem_storage"
|
||||
DATA_SESSION: Final = "session"
|
||||
|
||||
# https://developers.google.com/assistant/sdk/reference/rpc/languages
|
||||
SUPPORTED_LANGUAGE_CODES: Final = [
|
||||
"de-DE",
|
||||
@ -24,5 +28,3 @@ SUPPORTED_LANGUAGE_CODES: Final = [
|
||||
"ko-KR",
|
||||
"pt-BR",
|
||||
]
|
||||
|
||||
CONF_ENABLE_CONVERSATION_AGENT: Final = "enable_conversation_agent"
|
||||
|
@ -1,18 +1,38 @@
|
||||
"""Helper classes for Google Assistant SDK integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from typing import Any
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
from aiohttp import web
|
||||
from gassist_text import TextAssistant
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_player import (
|
||||
ATTR_MEDIA_ANNOUNCE,
|
||||
ATTR_MEDIA_CONTENT_ID,
|
||||
ATTR_MEDIA_CONTENT_TYPE,
|
||||
DOMAIN as DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
MediaType,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import CONF_ACCESS_TOKEN
|
||||
from homeassistant.const import ATTR_ENTITY_ID, CONF_ACCESS_TOKEN
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers.config_entry_oauth2_flow import OAuth2Session
|
||||
from homeassistant.helpers.event import async_call_later
|
||||
|
||||
from .const import CONF_LANGUAGE_CODE, DOMAIN, SUPPORTED_LANGUAGE_CODES
|
||||
from .const import (
|
||||
CONF_LANGUAGE_CODE,
|
||||
DATA_MEM_STORAGE,
|
||||
DATA_SESSION,
|
||||
DOMAIN,
|
||||
SUPPORTED_LANGUAGE_CODES,
|
||||
)
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
@ -28,12 +48,14 @@ DEFAULT_LANGUAGE_CODES = {
|
||||
}
|
||||
|
||||
|
||||
async def async_send_text_commands(commands: list[str], hass: HomeAssistant) -> None:
|
||||
async def async_send_text_commands(
|
||||
hass: HomeAssistant, commands: list[str], media_players: list[str] | None = None
|
||||
) -> None:
|
||||
"""Send text commands to Google Assistant Service."""
|
||||
# There can only be 1 entry (config_flow has single_instance_allowed)
|
||||
entry: ConfigEntry = hass.config_entries.async_entries(DOMAIN)[0]
|
||||
|
||||
session: OAuth2Session = hass.data[DOMAIN].get(entry.entry_id)
|
||||
session: OAuth2Session = hass.data[DOMAIN][entry.entry_id][DATA_SESSION]
|
||||
try:
|
||||
await session.async_ensure_token_valid()
|
||||
except aiohttp.ClientResponseError as err:
|
||||
@ -43,10 +65,32 @@ async def async_send_text_commands(commands: list[str], hass: HomeAssistant) ->
|
||||
|
||||
credentials = Credentials(session.token[CONF_ACCESS_TOKEN])
|
||||
language_code = entry.options.get(CONF_LANGUAGE_CODE, default_language_code(hass))
|
||||
with TextAssistant(credentials, language_code) as assistant:
|
||||
with TextAssistant(
|
||||
credentials, language_code, audio_out=bool(media_players)
|
||||
) as assistant:
|
||||
for command in commands:
|
||||
text_response = assistant.assist(command)[0]
|
||||
resp = assistant.assist(command)
|
||||
text_response = resp[0]
|
||||
_LOGGER.debug("command: %s\nresponse: %s", command, text_response)
|
||||
audio_response = resp[2]
|
||||
if media_players and audio_response:
|
||||
mem_storage: InMemoryStorage = hass.data[DOMAIN][entry.entry_id][
|
||||
DATA_MEM_STORAGE
|
||||
]
|
||||
audio_url = GoogleAssistantSDKAudioView.url.format(
|
||||
filename=mem_storage.store_and_get_identifier(audio_response)
|
||||
)
|
||||
await hass.services.async_call(
|
||||
DOMAIN_MP,
|
||||
SERVICE_PLAY_MEDIA,
|
||||
{
|
||||
ATTR_ENTITY_ID: media_players,
|
||||
ATTR_MEDIA_CONTENT_ID: audio_url,
|
||||
ATTR_MEDIA_CONTENT_TYPE: MediaType.MUSIC,
|
||||
ATTR_MEDIA_ANNOUNCE: True,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
|
||||
def default_language_code(hass: HomeAssistant):
|
||||
@ -55,3 +99,53 @@ def default_language_code(hass: HomeAssistant):
|
||||
if language_code in SUPPORTED_LANGUAGE_CODES:
|
||||
return language_code
|
||||
return DEFAULT_LANGUAGE_CODES.get(hass.config.language, "en-US")
|
||||
|
||||
|
||||
class InMemoryStorage:
|
||||
"""Temporarily store and retrieve data from in memory storage."""
|
||||
|
||||
def __init__(self, hass: HomeAssistant) -> None:
|
||||
"""Initialize InMemoryStorage."""
|
||||
self.hass: HomeAssistant = hass
|
||||
self.mem: dict[str, bytes] = {}
|
||||
|
||||
def store_and_get_identifier(self, data: bytes) -> str:
|
||||
"""
|
||||
Temporarily store data and return identifier to be able to retrieve it.
|
||||
|
||||
Data expires after 5 minutes.
|
||||
"""
|
||||
identifier: str = uuid.uuid1().hex
|
||||
self.mem[identifier] = data
|
||||
|
||||
def async_remove_from_mem(*_: Any) -> None:
|
||||
"""Cleanup memory."""
|
||||
self.mem.pop(identifier, None)
|
||||
|
||||
# Remove the entry from memory 5 minutes later
|
||||
async_call_later(self.hass, 5 * 60, async_remove_from_mem)
|
||||
|
||||
return identifier
|
||||
|
||||
def retrieve(self, identifier: str) -> bytes | None:
|
||||
"""Retrieve previously stored data."""
|
||||
return self.mem.get(identifier)
|
||||
|
||||
|
||||
class GoogleAssistantSDKAudioView(HomeAssistantView):
|
||||
"""Google Assistant SDK view to serve audio responses."""
|
||||
|
||||
requires_auth = True
|
||||
url = "/api/google_assistant_sdk/audio/{filename}"
|
||||
name = "api:google_assistant_sdk:audio"
|
||||
|
||||
def __init__(self, mem_storage: InMemoryStorage) -> None:
|
||||
"""Initialize GoogleAssistantSDKView."""
|
||||
self.mem_storage: InMemoryStorage = mem_storage
|
||||
|
||||
async def get(self, request: web.Request, filename: str) -> web.Response:
|
||||
"""Start a get request."""
|
||||
audio = self.mem_storage.retrieve(filename)
|
||||
if not audio:
|
||||
return web.Response(status=HTTPStatus.NOT_FOUND)
|
||||
return web.Response(body=audio, content_type="audio/mpeg")
|
||||
|
@ -2,9 +2,9 @@
|
||||
"domain": "google_assistant_sdk",
|
||||
"name": "Google Assistant SDK",
|
||||
"config_flow": true,
|
||||
"dependencies": ["application_credentials"],
|
||||
"dependencies": ["application_credentials", "http"],
|
||||
"documentation": "https://www.home-assistant.io/integrations/google_assistant_sdk/",
|
||||
"requirements": ["gassist-text==0.0.7"],
|
||||
"requirements": ["gassist-text==0.0.8"],
|
||||
"codeowners": ["@tronikos"],
|
||||
"iot_class": "cloud_polling",
|
||||
"integration_type": "service"
|
||||
|
@ -70,4 +70,4 @@ class BroadcastNotificationService(BaseNotificationService):
|
||||
commands.append(
|
||||
broadcast_commands(language_code)[1].format(message, target)
|
||||
)
|
||||
await async_send_text_commands(commands, self.hass)
|
||||
await async_send_text_commands(self.hass, commands)
|
||||
|
@ -8,3 +8,10 @@ send_text_command:
|
||||
example: turn off kitchen TV
|
||||
selector:
|
||||
text:
|
||||
media_player:
|
||||
name: Media Player Entity
|
||||
description: Name(s) of media player entities to play response on
|
||||
example: media_player.living_room_speaker
|
||||
selector:
|
||||
entity:
|
||||
domain: media_player
|
||||
|
@ -754,7 +754,7 @@ fritzconnection==1.10.3
|
||||
gTTS==2.2.4
|
||||
|
||||
# homeassistant.components.google_assistant_sdk
|
||||
gassist-text==0.0.7
|
||||
gassist-text==0.0.8
|
||||
|
||||
# homeassistant.components.google
|
||||
gcal-sync==4.1.2
|
||||
|
@ -573,7 +573,7 @@ fritzconnection==1.10.3
|
||||
gTTS==2.2.4
|
||||
|
||||
# homeassistant.components.google_assistant_sdk
|
||||
gassist-text==0.0.7
|
||||
gassist-text==0.0.8
|
||||
|
||||
# homeassistant.components.google
|
||||
gcal-sync==4.1.2
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""Tests for Google Assistant SDK."""
|
||||
from datetime import timedelta
|
||||
import http
|
||||
import time
|
||||
from unittest.mock import call, patch
|
||||
@ -10,12 +11,22 @@ from homeassistant.components.google_assistant_sdk import DOMAIN
|
||||
from homeassistant.config_entries import ConfigEntryState
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.util.dt import utcnow
|
||||
|
||||
from .conftest import ComponentSetup, ExpectedCredentials
|
||||
|
||||
from tests.common import async_fire_time_changed, async_mock_service
|
||||
from tests.test_util.aiohttp import AiohttpClientMocker
|
||||
|
||||
|
||||
async def fetch_api_url(hass_client, url):
|
||||
"""Fetch an API URL and return HTTP status and contents."""
|
||||
client = await hass_client()
|
||||
response = await client.get(url)
|
||||
contents = await response.read()
|
||||
return response.status, contents
|
||||
|
||||
|
||||
async def test_setup_success(
|
||||
hass: HomeAssistant, setup_integration: ComponentSetup
|
||||
) -> None:
|
||||
@ -129,7 +140,7 @@ async def test_send_text_command(
|
||||
blocking=True,
|
||||
)
|
||||
mock_text_assistant.assert_called_once_with(
|
||||
ExpectedCredentials(), expected_language_code
|
||||
ExpectedCredentials(), expected_language_code, audio_out=False
|
||||
)
|
||||
mock_text_assistant.assert_has_calls([call().__enter__().assist(command)])
|
||||
|
||||
@ -180,6 +191,88 @@ async def test_send_text_command_expired_token_refresh_failure(
|
||||
assert any(entry.async_get_active_flows(hass, {"reauth"})) == requires_reauth
|
||||
|
||||
|
||||
async def test_send_text_command_media_player(
|
||||
hass: HomeAssistant, setup_integration: ComponentSetup, hass_client
|
||||
) -> None:
|
||||
"""Test send_text_command with media_player."""
|
||||
await setup_integration()
|
||||
|
||||
play_media_calls = async_mock_service(hass, "media_player", "play_media")
|
||||
|
||||
command = "tell me a joke"
|
||||
media_player = "media_player.office_speaker"
|
||||
audio_response1 = b"joke1 audio response bytes"
|
||||
audio_response2 = b"joke2 audio response bytes"
|
||||
with patch(
|
||||
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
|
||||
side_effect=[
|
||||
("joke1 text", None, audio_response1),
|
||||
("joke2 text", None, audio_response2),
|
||||
],
|
||||
) as mock_assist_call:
|
||||
# Run the same command twice, getting different audio response each time.
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"send_text_command",
|
||||
{
|
||||
"command": command,
|
||||
"media_player": media_player,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
await hass.services.async_call(
|
||||
DOMAIN,
|
||||
"send_text_command",
|
||||
{
|
||||
"command": command,
|
||||
"media_player": media_player,
|
||||
},
|
||||
blocking=True,
|
||||
)
|
||||
|
||||
mock_assist_call.assert_has_calls([call(command), call(command)])
|
||||
assert len(play_media_calls) == 2
|
||||
for play_media_call in play_media_calls:
|
||||
assert play_media_call.data["entity_id"] == [media_player]
|
||||
assert play_media_call.data["media_content_id"].startswith(
|
||||
"/api/google_assistant_sdk/audio/"
|
||||
)
|
||||
|
||||
audio_url1 = play_media_calls[0].data["media_content_id"]
|
||||
audio_url2 = play_media_calls[1].data["media_content_id"]
|
||||
assert audio_url1 != audio_url2
|
||||
|
||||
# Assert that both audio responses can be served
|
||||
status, response = await fetch_api_url(hass_client, audio_url1)
|
||||
assert status == http.HTTPStatus.OK
|
||||
assert response == audio_response1
|
||||
status, response = await fetch_api_url(hass_client, audio_url2)
|
||||
assert status == http.HTTPStatus.OK
|
||||
assert response == audio_response2
|
||||
|
||||
# Assert a nonexistent URL returns 404
|
||||
status, _ = await fetch_api_url(
|
||||
hass_client, "/api/google_assistant_sdk/audio/nonexistent"
|
||||
)
|
||||
assert status == http.HTTPStatus.NOT_FOUND
|
||||
|
||||
# Assert that both audio responses can still be served before the 5 minutes expiration
|
||||
async_fire_time_changed(hass, utcnow() + timedelta(minutes=4))
|
||||
status, response = await fetch_api_url(hass_client, audio_url1)
|
||||
assert status == http.HTTPStatus.OK
|
||||
assert response == audio_response1
|
||||
status, response = await fetch_api_url(hass_client, audio_url2)
|
||||
assert status == http.HTTPStatus.OK
|
||||
assert response == audio_response2
|
||||
|
||||
# Assert that they cannot be served after the 5 minutes expiration
|
||||
async_fire_time_changed(hass, utcnow() + timedelta(minutes=6))
|
||||
status, response = await fetch_api_url(hass_client, audio_url1)
|
||||
assert status == http.HTTPStatus.NOT_FOUND
|
||||
status, response = await fetch_api_url(hass_client, audio_url2)
|
||||
assert status == http.HTTPStatus.NOT_FOUND
|
||||
|
||||
|
||||
async def test_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
setup_integration: ComponentSetup,
|
||||
|
@ -44,7 +44,9 @@ async def test_broadcast_no_targets(
|
||||
{notify.ATTR_MESSAGE: message},
|
||||
)
|
||||
await hass.async_block_till_done()
|
||||
mock_text_assistant.assert_called_once_with(ExpectedCredentials(), language_code)
|
||||
mock_text_assistant.assert_called_once_with(
|
||||
ExpectedCredentials(), language_code, audio_out=False
|
||||
)
|
||||
mock_text_assistant.assert_has_calls([call().__enter__().assist(expected_command)])
|
||||
|
||||
|
||||
@ -84,7 +86,7 @@ async def test_broadcast_one_target(
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
|
||||
return_value=["text_response", None],
|
||||
return_value=("text_response", None, b""),
|
||||
) as mock_assist_call:
|
||||
await hass.services.async_call(
|
||||
notify.DOMAIN,
|
||||
@ -108,7 +110,7 @@ async def test_broadcast_two_targets(
|
||||
expected_command2 = "broadcast to master bedroom time for dinner"
|
||||
with patch(
|
||||
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
|
||||
return_value=["text_response", None],
|
||||
return_value=("text_response", None, b""),
|
||||
) as mock_assist_call:
|
||||
await hass.services.async_call(
|
||||
notify.DOMAIN,
|
||||
@ -129,7 +131,7 @@ async def test_broadcast_empty_message(
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.google_assistant_sdk.helpers.TextAssistant.assist",
|
||||
return_value=["text_response", None],
|
||||
return_value=("text_response", None, b""),
|
||||
) as mock_assist_call:
|
||||
await hass.services.async_call(
|
||||
notify.DOMAIN,
|
||||
|
Loading…
x
Reference in New Issue
Block a user