Allow TTS streams to generate temporary media source IDs (#145080)

* Allow TTS streams to generate temporary media source IDs

* Update tests/components/tts/test_media_source.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update assist snapshots

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
Paulus Schoutsen 2025-05-19 12:04:19 -04:00 committed by GitHub
parent cadbe885d1
commit e09dde2ea9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 71 additions and 37 deletions

View File

@ -20,9 +20,6 @@ import hass_nabucasa
import voluptuous as vol
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
from homeassistant.components.tts import (
generate_media_source_id as tts_generate_media_source_id,
)
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@ -1276,26 +1273,10 @@ class PipelineRun:
)
)
try:
# Synthesize audio and get URL
tts_media_id = tts_generate_media_source_id(
self.hass,
tts_input,
engine=self.tts_stream.engine,
language=self.tts_stream.language,
options=self.tts_stream.options,
)
except Exception as src_error:
_LOGGER.exception("Unexpected error during text-to-speech")
raise TextToSpeechError(
code="tts-failed",
message="Unexpected error during text-to-speech",
) from src_error
self.tts_stream.async_set_message(tts_input)
tts_output = {
"media_id": tts_media_id,
"media_id": self.tts_stream.media_source_id,
"token": self.tts_stream.token,
"url": self.tts_stream.url,
"mime_type": self.tts_stream.content_type,

View File

@ -25,6 +25,9 @@ import voluptuous as vol
from homeassistant.components import ffmpeg, websocket_api
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.media_source import (
generate_media_source_id as ms_generate_media_source_id,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT
from homeassistant.core import (
@ -58,6 +61,7 @@ from .const import (
DEFAULT_CACHE_DIR,
DEFAULT_TIME_MEMORY,
DOMAIN,
MEDIA_SOURCE_STREAM_PATH,
TtsAudioType,
)
from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse
@ -273,9 +277,17 @@ async def async_get_media_source_audio(
media_source_id: str,
) -> tuple[str, bytes]:
"""Get TTS audio as extension, data."""
manager = hass.data[DATA_TTS_MANAGER]
parsed = parse_media_source_id(media_source_id)
stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"])
stream.async_set_message(parsed["message"])
if "stream" in parsed:
stream = manager.async_get_result_stream(
parsed["stream"] # type: ignore[typeddict-item]
)
if stream is None:
raise ValueError("Stream not found")
else:
stream = manager.async_create_result_stream(**parsed["options"])
stream.async_set_message(parsed["message"])
data = b"".join([chunk async for chunk in stream.async_stream_result()])
return stream.extension, data
@ -478,6 +490,14 @@ class ResultStream:
"""Get the URL to stream the result."""
return f"/api/tts_proxy/{self.token}"
# Media source ID that refers to this stream by its token rather than by
# engine/message/options. The identifier handed to
# ms_generate_media_source_id is "<MEDIA_SOURCE_STREAM_PATH>/<token>" under
# DOMAIN; the snapshots in this change show the resulting form as
# "media-source://tts/-stream-/<token>".
# Declared as a cached_property, so the value is computed on first access.
@cached_property
def media_source_id(self) -> str:
"""Get the media source ID of this stream."""
return ms_generate_media_source_id(
DOMAIN,
f"{MEDIA_SOURCE_STREAM_PATH}/{self.token}",
)
@cached_property
def _result_cache(self) -> asyncio.Future[TTSCache]:
"""Get the future that returns the cache."""

View File

@ -30,4 +30,6 @@ DATA_COMPONENT: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN)
DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager")
MEDIA_SOURCE_STREAM_PATH = "-stream-"
type TtsAudioType = tuple[str | None, bytes | None]

View File

@ -19,7 +19,7 @@ from homeassistant.components.media_source import (
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN
from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN, MEDIA_SOURCE_STREAM_PATH
from .helper import get_engine_instance
URL_QUERY_TTS_OPTIONS = "tts_options"
@ -81,10 +81,22 @@ class ParsedMediaSourceId(TypedDict):
message: str
# Shape returned by parse_media_source_id when the ID's path starts with
# "<MEDIA_SOURCE_STREAM_PATH>/"; callers tell it apart from
# ParsedMediaSourceId by checking for the "stream" key.
class ParsedMediaSourceStreamId(TypedDict):
"""Parsed media source ID for a stream."""
# Remainder of the path after the "<MEDIA_SOURCE_STREAM_PATH>/" prefix;
# callers pass it to manager.async_get_result_stream to look up the stream.
stream: str
@callback
def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId:
def parse_media_source_id(
media_source_id: str,
) -> ParsedMediaSourceId | ParsedMediaSourceStreamId:
"""Turn a media source ID into options."""
parsed = URL(media_source_id)
if parsed.path.startswith(f"{MEDIA_SOURCE_STREAM_PATH}/"):
return {"stream": parsed.path[len(MEDIA_SOURCE_STREAM_PATH) + 1 :]}
if URL_QUERY_TTS_OPTIONS in parsed.query:
try:
options = json.loads(parsed.query[URL_QUERY_TTS_OPTIONS])
@ -122,17 +134,24 @@ class TTSMediaSource(MediaSource):
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve media to a url."""
manager = self.hass.data[DATA_TTS_MANAGER]
try:
parsed = parse_media_source_id(item.identifier)
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
**parsed["options"]
)
stream.async_set_message(parsed["message"])
if "stream" in parsed:
stream = manager.async_get_result_stream(
parsed["stream"], # type: ignore[typeddict-item]
)
else:
stream = manager.async_create_result_stream(**parsed["options"])
stream.async_set_message(parsed["message"])
except Unresolvable:
raise
except HomeAssistantError as err:
raise Unresolvable(str(err)) from err
if stream is None:
raise Unresolvable("Stream not found")
return PlayMedia(stream.url, stream.content_type)
async def async_browse_media(

View File

@ -84,7 +84,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -183,7 +183,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -282,7 +282,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -405,7 +405,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',

View File

@ -139,7 +139,7 @@
dict({
'data': dict({
'tts_output': dict({
'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D',
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',

View File

@ -80,7 +80,7 @@
# name: test_audio_pipeline.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -171,7 +171,7 @@
# name: test_audio_pipeline_debug.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -274,7 +274,7 @@
# name: test_audio_pipeline_with_enhancements.6
dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',
@ -387,7 +387,7 @@
# name: test_audio_pipeline_with_wake_word_no_timeout.8
dict({
'tts_output': dict({
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
'media_id': 'media-source://tts/-stream-/test_token.mp3',
'mime_type': 'audio/mpeg',
'token': 'test_token.mp3',
'url': '/api/tts_proxy/test_token.mp3',

View File

@ -17,6 +17,7 @@ from homeassistant.setup import async_setup_component
from .common import (
DEFAULT_LANG,
MockResultStream,
MockTTSEntity,
MockTTSProvider,
mock_config_entry_setup,
@ -198,6 +199,17 @@ async def test_resolving(
assert language == "de_DE"
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
# Test with result stream
stream = MockResultStream(hass, "wav", b"")
media = await media_source.async_resolve_media(hass, stream.media_source_id, None)
assert media.url == stream.url
assert media.mime_type == stream.content_type
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(
hass, "media-source://tts/-stream-/not-a-valid-token", None
)
@pytest.mark.parametrize(
("mock_provider", "mock_tts_entity"),