Assist satellite to use TTS tokens for announcements (#140336)

* Migrate Assist Satellite to use token

* Fix tests

* Fix tests
This commit is contained in:
Paulus Schoutsen 2025-03-13 09:36:38 -04:00 committed by GitHub
parent e710d3699c
commit f32bb1a318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 113 additions and 40 deletions

View File

@ -23,9 +23,6 @@ from homeassistant.components.assist_pipeline import (
vad,
)
from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.components.tts import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.core import Context, callback
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import chat_session, entity
@ -98,6 +95,9 @@ class AssistSatelliteAnnouncement:
original_media_id: str
"""The raw media ID before processing."""
tts_token: str | None
"""The TTS token of the media."""
media_id_source: Literal["url", "media_id", "tts"]
"""Source of the media ID."""
@ -474,6 +474,7 @@ class AssistSatelliteEntity(entity.Entity):
) -> AssistSatelliteAnnouncement:
"""Resolve the media ID."""
media_id_source: Literal["url", "media_id", "tts"] | None = None
tts_token: str | None = None
if media_id:
original_media_id = media_id
@ -484,6 +485,10 @@ class AssistSatelliteEntity(entity.Entity):
pipeline_id = self._resolve_pipeline()
pipeline = async_get_pipeline(self.hass, pipeline_id)
engine = tts.async_resolve_engine(self.hass, pipeline.tts_engine)
if engine is None:
raise HomeAssistantError(f"TTS engine {pipeline.tts_engine} not found")
tts_options: dict[str, Any] = {}
if pipeline.tts_voice is not None:
tts_options[tts.ATTR_VOICE] = pipeline.tts_voice
@ -491,14 +496,23 @@ class AssistSatelliteEntity(entity.Entity):
if self.tts_options is not None:
tts_options.update(self.tts_options)
media_id = tts_generate_media_source_id(
stream = tts.async_create_stream(
self.hass,
message,
engine=pipeline.tts_engine,
engine=engine,
language=pipeline.tts_language,
options=tts_options,
)
stream.async_set_message(message)
tts_token = stream.token
media_id = stream.url
original_media_id = tts.generate_media_source_id(
self.hass,
message,
engine=engine,
language=pipeline.tts_language,
options=tts_options,
)
original_media_id = media_id
if media_source.is_media_source_id(media_id):
if not media_id_source:
@ -517,5 +531,9 @@ class AssistSatelliteEntity(entity.Entity):
media_id = async_process_play_media_url(self.hass, media_id)
return AssistSatelliteAnnouncement(
message, media_id, original_media_id, media_id_source
message=message,
media_id=media_id,
original_media_id=original_media_id,
tts_token=tts_token,
media_id_source=media_id_source,
)

View File

@ -31,6 +31,8 @@ from homeassistant.exceptions import HomeAssistantError
from . import ENTITY_ID
from .conftest import MockAssistSatellite
from tests.components.tts.common import MockResultStream
@pytest.fixture
def mock_chat_session_conversation_id() -> Generator[Mock]:
@ -186,8 +188,9 @@ async def test_new_pipeline_cancels_pipeline(
{"message": "Hello"},
AssistSatelliteAnnouncement(
message="Hello",
media_id="https://www.home-assistant.io/resolved.mp3",
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
original_media_id="media-source://bla",
tts_token="test-token",
media_id_source="tts",
),
),
@ -200,6 +203,7 @@ async def test_new_pipeline_cancels_pipeline(
message="Hello",
media_id="https://www.home-assistant.io/resolved.mp3",
original_media_id="media-source://given",
tts_token=None,
media_id_source="media_id",
),
),
@ -209,6 +213,7 @@ async def test_new_pipeline_cancels_pipeline(
message="",
media_id="http://example.com/bla.mp3",
original_media_id="http://example.com/bla.mp3",
tts_token=None,
media_id_source="url",
),
),
@ -243,9 +248,17 @@ async def test_announce(
with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
"homeassistant.components.tts.generate_media_source_id",
new=tts_generate_media_source_id,
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
@ -500,7 +513,8 @@ async def test_vad_sensitivity_entity_not_found(
"Better system prompt",
AssistSatelliteAnnouncement(
message="Hello",
media_id="https://www.home-assistant.io/resolved.mp3",
media_id="http://10.10.10.10:8123/api/tts_proxy/test-token",
tts_token="test-token",
original_media_id="media-source://generated",
media_id_source="tts",
),
@ -517,6 +531,7 @@ async def test_vad_sensitivity_entity_not_found(
AssistSatelliteAnnouncement(
message="Hello",
media_id="https://www.home-assistant.io/resolved.mp3",
tts_token=None,
original_media_id="media-source://given",
media_id_source="media_id",
),
@ -530,6 +545,7 @@ async def test_vad_sensitivity_entity_not_found(
AssistSatelliteAnnouncement(
message="",
media_id="http://example.com/given.mp3",
tts_token=None,
original_media_id="http://example.com/given.mp3",
media_id_source="url",
),
@ -554,9 +570,17 @@ async def test_start_conversation(
with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://generated",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(

View File

@ -4,28 +4,28 @@ from unittest.mock import patch
import pytest
from homeassistant.components.media_source import PlayMedia
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.setup import async_setup_component
from .conftest import TEST_DOMAIN, MockAssistSatellite
from tests.components.tts.common import MockResultStream
@pytest.fixture
def mock_tts():
async def mock_tts(hass: HomeAssistant):
"""Mock TTS service."""
assert await async_setup_component(hass, "tts", {})
with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
):
yield
@ -41,9 +41,13 @@ async def test_broadcast_intent(
) -> None:
"""Test we can invoke a broadcast intent."""
result = await intent.async_handle(
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
)
with patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud",
):
result = await intent.async_handle(
hass, "test", intent.INTENT_BROADCAST, {"message": {"value": "Hello"}}
)
assert result.as_dict() == {
"card": {},
@ -71,13 +75,17 @@ async def test_broadcast_intent(
assert len(entity2.announcements) == 1
assert len(entity_no_features.announcements) == 0
result = await intent.async_handle(
hass,
"test",
intent.INTENT_BROADCAST,
{"message": {"value": "Hello"}},
device_id=entity.device_entry.id,
)
with patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud",
):
result = await intent.async_handle(
hass,
"test",
intent.INTENT_BROADCAST,
{"message": {"value": "Hello"}},
device_id=entity.device_entry.id,
)
# Broadcast doesn't target the device that triggered it.
assert result.as_dict() == {
"card": {},

View File

@ -41,7 +41,6 @@ from homeassistant.components.esphome.assist_satellite import (
EsphomeAssistSatellite,
VoiceAssistantUDPServer,
)
from homeassistant.components.media_source import PlayMedia
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
@ -57,6 +56,8 @@ from homeassistant.helpers.entity_component import EntityComponent
from .conftest import MockESPHomeDevice
from tests.components.tts.common import MockResultStream
def get_satellite_entity(
hass: HomeAssistant, mac_address: str
@ -1209,22 +1210,23 @@ async def test_announce_message(
media_id: str, timeout: float, text: str
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/resolved.mp3"
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
assert text == "test-text"
done.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.media_source.async_resolve_media",
return_value=PlayMedia(
url="https://www.home-assistant.io/resolved.mp3",
mime_type="audio/mp3",
),
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud_tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch.object(
mock_client,

View File

@ -270,6 +270,8 @@ async def mock_config_entry_setup(
class MockResultStream(ResultStream):
"""Mock result stream."""
test_set_message: str | None = None
def __init__(self, hass: HomeAssistant, extension: str, data: bytes) -> None:
"""Initialize the result stream."""
super().__init__(
@ -285,6 +287,11 @@ class MockResultStream(ResultStream):
hass.data[DATA_TTS_MANAGER].token_to_stream[self.token] = self
self._mock_data = data
@callback
def async_set_message(self, message: str) -> None:
"""Set message to be generated."""
self.test_set_message = message
async def async_stream_result(self):
"""Stream the result."""
yield self._mock_data

View File

@ -27,6 +27,8 @@ from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.setup import async_setup_component
from tests.components.tts.common import MockResultStream
_ONE_SECOND = 16000 * 2 # 16Khz 16-bit
_MEDIA_ID = "12345"
@ -879,6 +881,7 @@ async def test_announce(
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -926,6 +929,7 @@ async def test_voip_id_is_ip_address(
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -978,6 +982,7 @@ async def test_announce_timeout(
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -1018,6 +1023,7 @@ async def test_start_conversation(
announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
media_id=_MEDIA_ID,
tts_token="test-token",
original_media_id=_MEDIA_ID,
media_id_source="tts",
)
@ -1162,8 +1168,16 @@ async def test_start_conversation_user_doesnt_pick_up(
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.assist_satellite.entity.tts_generate_media_source_id",
return_value="test media id",
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="test tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
):
satellite.transport = Mock()