Change assist satellite announce method signature (#126299)

This commit is contained in:
Paulus Schoutsen 2024-09-20 09:09:37 -04:00 committed by GitHub
parent 41ffa8d6db
commit 604c848dec
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 52 additions and 17 deletions

View File

@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
from .entity import ( from .entity import (
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration, AssistSatelliteConfiguration,
AssistSatelliteEntity, AssistSatelliteEntity,
AssistSatelliteEntityDescription, AssistSatelliteEntityDescription,
@ -22,6 +23,7 @@ from .websocket_api import async_register_websocket_api
__all__ = [ __all__ = [
"DOMAIN", "DOMAIN",
"AssistSatelliteAnnouncement",
"AssistSatelliteEntity", "AssistSatelliteEntity",
"AssistSatelliteConfiguration", "AssistSatelliteConfiguration",
"AssistSatelliteEntityDescription", "AssistSatelliteEntityDescription",

View File

@ -8,7 +8,7 @@ from dataclasses import dataclass
from enum import StrEnum from enum import StrEnum
import logging import logging
import time import time
from typing import Any, Final, final from typing import Any, Final, Literal, final
from homeassistant.components import media_source, stt, tts from homeassistant.components import media_source, stt, tts
from homeassistant.components.assist_pipeline import ( from homeassistant.components.assist_pipeline import (
@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
"""Maximum number of simultaneous wake words allowed (0 for no limit).""" """Maximum number of simultaneous wake words allowed (0 for no limit)."""
@dataclass
class AssistSatelliteAnnouncement:
"""Announcement to be made."""
message: str
"""Message to be spoken."""
media_id: str
"""Media ID to be played."""
media_id_source: Literal["url", "media_id", "tts"]
class AssistSatelliteEntity(entity.Entity): class AssistSatelliteEntity(entity.Entity):
"""Entity encapsulating the state and functionality of an Assist satellite.""" """Entity encapsulating the state and functionality of an Assist satellite."""
@ -174,10 +187,13 @@ class AssistSatelliteEntity(entity.Entity):
""" """
await self._cancel_running_pipeline() await self._cancel_running_pipeline()
media_id_source: Literal["url", "media_id", "tts"] | None = None
if message is None: if message is None:
message = "" message = ""
if not media_id: if not media_id:
media_id_source = "tts"
# Synthesize audio and get URL # Synthesize audio and get URL
pipeline_id = self._resolve_pipeline() pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id) pipeline = async_get_pipeline(self.hass, pipeline_id)
@ -198,6 +214,8 @@ class AssistSatelliteEntity(entity.Entity):
) )
if media_source.is_media_source_id(media_id): if media_source.is_media_source_id(media_id):
if not media_id_source:
media_id_source = "media_id"
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
self.hass, self.hass,
media_id, media_id,
@ -205,6 +223,9 @@ class AssistSatelliteEntity(entity.Entity):
) )
media_id = media.url media_id = media.url
if not media_id_source:
media_id_source = "url"
# Resolve to full URL # Resolve to full URL
media_id = async_process_play_media_url(self.hass, media_id) media_id = async_process_play_media_url(self.hass, media_id)
@ -216,12 +237,14 @@ class AssistSatelliteEntity(entity.Entity):
try: try:
# Block until announcement is finished # Block until announcement is finished
await self.async_announce(message, media_id) await self.async_announce(
AssistSatelliteAnnouncement(message, media_id, media_id_source)
)
finally: finally:
self._is_announcing = False self._is_announcing = False
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD) self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
async def async_announce(self, message: str, media_id: str) -> None: async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on the satellite. """Announce media on the satellite.
Should block until the announcement is done playing. Should block until the announcement is done playing.

View File

@ -313,18 +313,20 @@ class EsphomeAssistSatellite(
self.cli.send_voice_assistant_event(event_type, data_to_send) self.cli.send_voice_assistant_event(event_type, data_to_send)
async def async_announce(self, message: str, media_id: str) -> None: async def async_announce(
self, announcement: assist_satellite.AssistSatelliteAnnouncement
) -> None:
"""Announce media on the satellite. """Announce media on the satellite.
Should block until the announcement is done playing. Should block until the announcement is done playing.
""" """
_LOGGER.debug( _LOGGER.debug(
"Waiting for announcement to finished (message=%s, media_id=%s)", "Waiting for announcement to finished (message=%s, media_id=%s)",
message, announcement.message,
media_id, announcement.media_id,
) )
await self.cli.send_voice_assistant_announcement_await_response( await self.cli.send_voice_assistant_announcement_await_response(
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
) )
async def handle_pipeline_start( async def handle_pipeline_start(

View File

@ -8,6 +8,7 @@ import pytest
from homeassistant.components.assist_pipeline import PipelineEvent from homeassistant.components.assist_pipeline import PipelineEvent
from homeassistant.components.assist_satellite import ( from homeassistant.components.assist_satellite import (
DOMAIN as AS_DOMAIN, DOMAIN as AS_DOMAIN,
AssistSatelliteAnnouncement,
AssistSatelliteConfiguration, AssistSatelliteConfiguration,
AssistSatelliteEntity, AssistSatelliteEntity,
AssistSatelliteEntityFeature, AssistSatelliteEntityFeature,
@ -63,9 +64,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
"""Handle pipeline events.""" """Handle pipeline events."""
self.events.append(event) self.events.append(event)
async def async_announce(self, message: str, media_id: str) -> None: async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
"""Announce media on a device.""" """Announce media on a device."""
self.announcements.append((message, media_id)) self.announcements.append(announcement)
@callback @callback
def async_get_configuration(self) -> AssistSatelliteConfiguration: def async_get_configuration(self) -> AssistSatelliteConfiguration:

View File

@ -17,7 +17,10 @@ from homeassistant.components.assist_pipeline import (
async_update_pipeline, async_update_pipeline,
vad, vad,
) )
from homeassistant.components.assist_satellite import SatelliteBusyError from homeassistant.components.assist_satellite import (
AssistSatelliteAnnouncement,
SatelliteBusyError,
)
from homeassistant.components.assist_satellite.entity import AssistSatelliteState from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.media_source import PlayMedia from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
@ -159,18 +162,22 @@ async def test_new_pipeline_cancels_pipeline(
[ [
( (
{"message": "Hello"}, {"message": "Hello"},
("Hello", "https://www.home-assistant.io/resolved.mp3"), AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
),
), ),
( (
{ {
"message": "Hello", "message": "Hello",
"media_id": "http://example.com/bla.mp3", "media_id": "media-source://bla",
}, },
("Hello", "http://example.com/bla.mp3"), AssistSatelliteAnnouncement(
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
),
), ),
( (
{"media_id": "http://example.com/bla.mp3"}, {"media_id": "http://example.com/bla.mp3"},
("", "http://example.com/bla.mp3"), AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
), ),
], ],
) )
@ -195,10 +202,10 @@ async def test_announce(
original_announce = entity.async_announce original_announce = entity.async_announce
announce_started = asyncio.Event() announce_started = asyncio.Event()
async def async_announce(message, media_id): async def async_announce(announcement):
# Verify state change # Verify state change
assert entity.state == AssistSatelliteState.RESPONDING assert entity.state == AssistSatelliteState.RESPONDING
await original_announce(message, media_id) await original_announce(announcement)
announce_started.set() announce_started.set()
def tts_generate_media_source_id( def tts_generate_media_source_id(
@ -249,7 +256,7 @@ async def test_announce_busy(
announce_started = asyncio.Event() announce_started = asyncio.Event()
got_error = asyncio.Event() got_error = asyncio.Event()
async def async_announce(message, media_id): async def async_announce(announcement):
announce_started.set() announce_started.set()
# Block so we can do another announcement # Block so we can do another announcement