mirror of
https://github.com/home-assistant/core.git
synced 2025-07-25 22:27:07 +00:00
Change assist satellite announce method signature (#126299)
This commit is contained in:
parent
41ffa8d6db
commit
604c848dec
@ -12,6 +12,7 @@ from homeassistant.helpers.typing import ConfigType
|
|||||||
|
|
||||||
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
|
from .const import DOMAIN, DOMAIN_DATA, AssistSatelliteEntityFeature
|
||||||
from .entity import (
|
from .entity import (
|
||||||
|
AssistSatelliteAnnouncement,
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityDescription,
|
AssistSatelliteEntityDescription,
|
||||||
@ -22,6 +23,7 @@ from .websocket_api import async_register_websocket_api
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DOMAIN",
|
"DOMAIN",
|
||||||
|
"AssistSatelliteAnnouncement",
|
||||||
"AssistSatelliteEntity",
|
"AssistSatelliteEntity",
|
||||||
"AssistSatelliteConfiguration",
|
"AssistSatelliteConfiguration",
|
||||||
"AssistSatelliteEntityDescription",
|
"AssistSatelliteEntityDescription",
|
||||||
|
@ -8,7 +8,7 @@ from dataclasses import dataclass
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Any, Final, final
|
from typing import Any, Final, Literal, final
|
||||||
|
|
||||||
from homeassistant.components import media_source, stt, tts
|
from homeassistant.components import media_source, stt, tts
|
||||||
from homeassistant.components.assist_pipeline import (
|
from homeassistant.components.assist_pipeline import (
|
||||||
@ -86,6 +86,19 @@ class AssistSatelliteConfiguration:
|
|||||||
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
|
"""Maximum number of simultaneous wake words allowed (0 for no limit)."""
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class AssistSatelliteAnnouncement:
|
||||||
|
"""Announcement to be made."""
|
||||||
|
|
||||||
|
message: str
|
||||||
|
"""Message to be spoken."""
|
||||||
|
|
||||||
|
media_id: str
|
||||||
|
"""Media ID to be played."""
|
||||||
|
|
||||||
|
media_id_source: Literal["url", "media_id", "tts"]
|
||||||
|
|
||||||
|
|
||||||
class AssistSatelliteEntity(entity.Entity):
|
class AssistSatelliteEntity(entity.Entity):
|
||||||
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
"""Entity encapsulating the state and functionality of an Assist satellite."""
|
||||||
|
|
||||||
@ -174,10 +187,13 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
"""
|
"""
|
||||||
await self._cancel_running_pipeline()
|
await self._cancel_running_pipeline()
|
||||||
|
|
||||||
|
media_id_source: Literal["url", "media_id", "tts"] | None = None
|
||||||
|
|
||||||
if message is None:
|
if message is None:
|
||||||
message = ""
|
message = ""
|
||||||
|
|
||||||
if not media_id:
|
if not media_id:
|
||||||
|
media_id_source = "tts"
|
||||||
# Synthesize audio and get URL
|
# Synthesize audio and get URL
|
||||||
pipeline_id = self._resolve_pipeline()
|
pipeline_id = self._resolve_pipeline()
|
||||||
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
pipeline = async_get_pipeline(self.hass, pipeline_id)
|
||||||
@ -198,6 +214,8 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if media_source.is_media_source_id(media_id):
|
if media_source.is_media_source_id(media_id):
|
||||||
|
if not media_id_source:
|
||||||
|
media_id_source = "media_id"
|
||||||
media = await media_source.async_resolve_media(
|
media = await media_source.async_resolve_media(
|
||||||
self.hass,
|
self.hass,
|
||||||
media_id,
|
media_id,
|
||||||
@ -205,6 +223,9 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
)
|
)
|
||||||
media_id = media.url
|
media_id = media.url
|
||||||
|
|
||||||
|
if not media_id_source:
|
||||||
|
media_id_source = "url"
|
||||||
|
|
||||||
# Resolve to full URL
|
# Resolve to full URL
|
||||||
media_id = async_process_play_media_url(self.hass, media_id)
|
media_id = async_process_play_media_url(self.hass, media_id)
|
||||||
|
|
||||||
@ -216,12 +237,14 @@ class AssistSatelliteEntity(entity.Entity):
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
# Block until announcement is finished
|
# Block until announcement is finished
|
||||||
await self.async_announce(message, media_id)
|
await self.async_announce(
|
||||||
|
AssistSatelliteAnnouncement(message, media_id, media_id_source)
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
self._is_announcing = False
|
self._is_announcing = False
|
||||||
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
self._set_state(AssistSatelliteState.LISTENING_WAKE_WORD)
|
||||||
|
|
||||||
async def async_announce(self, message: str, media_id: str) -> None:
|
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||||
"""Announce media on the satellite.
|
"""Announce media on the satellite.
|
||||||
|
|
||||||
Should block until the announcement is done playing.
|
Should block until the announcement is done playing.
|
||||||
|
@ -313,18 +313,20 @@ class EsphomeAssistSatellite(
|
|||||||
|
|
||||||
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
self.cli.send_voice_assistant_event(event_type, data_to_send)
|
||||||
|
|
||||||
async def async_announce(self, message: str, media_id: str) -> None:
|
async def async_announce(
|
||||||
|
self, announcement: assist_satellite.AssistSatelliteAnnouncement
|
||||||
|
) -> None:
|
||||||
"""Announce media on the satellite.
|
"""Announce media on the satellite.
|
||||||
|
|
||||||
Should block until the announcement is done playing.
|
Should block until the announcement is done playing.
|
||||||
"""
|
"""
|
||||||
_LOGGER.debug(
|
_LOGGER.debug(
|
||||||
"Waiting for announcement to finished (message=%s, media_id=%s)",
|
"Waiting for announcement to finished (message=%s, media_id=%s)",
|
||||||
message,
|
announcement.message,
|
||||||
media_id,
|
announcement.media_id,
|
||||||
)
|
)
|
||||||
await self.cli.send_voice_assistant_announcement_await_response(
|
await self.cli.send_voice_assistant_announcement_await_response(
|
||||||
media_id, _ANNOUNCEMENT_TIMEOUT_SEC, message
|
announcement.media_id, _ANNOUNCEMENT_TIMEOUT_SEC, announcement.message
|
||||||
)
|
)
|
||||||
|
|
||||||
async def handle_pipeline_start(
|
async def handle_pipeline_start(
|
||||||
|
@ -8,6 +8,7 @@ import pytest
|
|||||||
from homeassistant.components.assist_pipeline import PipelineEvent
|
from homeassistant.components.assist_pipeline import PipelineEvent
|
||||||
from homeassistant.components.assist_satellite import (
|
from homeassistant.components.assist_satellite import (
|
||||||
DOMAIN as AS_DOMAIN,
|
DOMAIN as AS_DOMAIN,
|
||||||
|
AssistSatelliteAnnouncement,
|
||||||
AssistSatelliteConfiguration,
|
AssistSatelliteConfiguration,
|
||||||
AssistSatelliteEntity,
|
AssistSatelliteEntity,
|
||||||
AssistSatelliteEntityFeature,
|
AssistSatelliteEntityFeature,
|
||||||
@ -63,9 +64,9 @@ class MockAssistSatellite(AssistSatelliteEntity):
|
|||||||
"""Handle pipeline events."""
|
"""Handle pipeline events."""
|
||||||
self.events.append(event)
|
self.events.append(event)
|
||||||
|
|
||||||
async def async_announce(self, message: str, media_id: str) -> None:
|
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||||
"""Announce media on a device."""
|
"""Announce media on a device."""
|
||||||
self.announcements.append((message, media_id))
|
self.announcements.append(announcement)
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def async_get_configuration(self) -> AssistSatelliteConfiguration:
|
def async_get_configuration(self) -> AssistSatelliteConfiguration:
|
||||||
|
@ -17,7 +17,10 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
async_update_pipeline,
|
async_update_pipeline,
|
||||||
vad,
|
vad,
|
||||||
)
|
)
|
||||||
from homeassistant.components.assist_satellite import SatelliteBusyError
|
from homeassistant.components.assist_satellite import (
|
||||||
|
AssistSatelliteAnnouncement,
|
||||||
|
SatelliteBusyError,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
|
||||||
from homeassistant.components.media_source import PlayMedia
|
from homeassistant.components.media_source import PlayMedia
|
||||||
from homeassistant.config_entries import ConfigEntry
|
from homeassistant.config_entries import ConfigEntry
|
||||||
@ -159,18 +162,22 @@ async def test_new_pipeline_cancels_pipeline(
|
|||||||
[
|
[
|
||||||
(
|
(
|
||||||
{"message": "Hello"},
|
{"message": "Hello"},
|
||||||
("Hello", "https://www.home-assistant.io/resolved.mp3"),
|
AssistSatelliteAnnouncement(
|
||||||
|
"Hello", "https://www.home-assistant.io/resolved.mp3", "tts"
|
||||||
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
{
|
{
|
||||||
"message": "Hello",
|
"message": "Hello",
|
||||||
"media_id": "http://example.com/bla.mp3",
|
"media_id": "media-source://bla",
|
||||||
},
|
},
|
||||||
("Hello", "http://example.com/bla.mp3"),
|
AssistSatelliteAnnouncement(
|
||||||
|
"Hello", "https://www.home-assistant.io/resolved.mp3", "media_id"
|
||||||
|
),
|
||||||
),
|
),
|
||||||
(
|
(
|
||||||
{"media_id": "http://example.com/bla.mp3"},
|
{"media_id": "http://example.com/bla.mp3"},
|
||||||
("", "http://example.com/bla.mp3"),
|
AssistSatelliteAnnouncement("", "http://example.com/bla.mp3", "url"),
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
@ -195,10 +202,10 @@ async def test_announce(
|
|||||||
original_announce = entity.async_announce
|
original_announce = entity.async_announce
|
||||||
announce_started = asyncio.Event()
|
announce_started = asyncio.Event()
|
||||||
|
|
||||||
async def async_announce(message, media_id):
|
async def async_announce(announcement):
|
||||||
# Verify state change
|
# Verify state change
|
||||||
assert entity.state == AssistSatelliteState.RESPONDING
|
assert entity.state == AssistSatelliteState.RESPONDING
|
||||||
await original_announce(message, media_id)
|
await original_announce(announcement)
|
||||||
announce_started.set()
|
announce_started.set()
|
||||||
|
|
||||||
def tts_generate_media_source_id(
|
def tts_generate_media_source_id(
|
||||||
@ -249,7 +256,7 @@ async def test_announce_busy(
|
|||||||
announce_started = asyncio.Event()
|
announce_started = asyncio.Event()
|
||||||
got_error = asyncio.Event()
|
got_error = asyncio.Event()
|
||||||
|
|
||||||
async def async_announce(message, media_id):
|
async def async_announce(announcement):
|
||||||
announce_started.set()
|
announce_started.set()
|
||||||
|
|
||||||
# Block so we can do another announcement
|
# Block so we can do another announcement
|
||||||
|
Loading…
x
Reference in New Issue
Block a user