mirror of
https://github.com/home-assistant/core.git
synced 2025-07-22 12:47:08 +00:00
Allow TTS streams to generate temporary media source IDs (#145080)
* Allow TTS streams to generate temporary media source IDs * Update tests/components/tts/test_media_source.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update assist snapshots --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
parent
cadbe885d1
commit
e09dde2ea9
@ -20,9 +20,6 @@ import hass_nabucasa
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import conversation, stt, tts, wake_word, websocket_api
|
||||
from homeassistant.components.tts import (
|
||||
generate_media_source_id as tts_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.const import ATTR_SUPPORTED_FEATURES, MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -1276,26 +1273,10 @@ class PipelineRun:
|
||||
)
|
||||
)
|
||||
|
||||
try:
|
||||
# Synthesize audio and get URL
|
||||
tts_media_id = tts_generate_media_source_id(
|
||||
self.hass,
|
||||
tts_input,
|
||||
engine=self.tts_stream.engine,
|
||||
language=self.tts_stream.language,
|
||||
options=self.tts_stream.options,
|
||||
)
|
||||
except Exception as src_error:
|
||||
_LOGGER.exception("Unexpected error during text-to-speech")
|
||||
raise TextToSpeechError(
|
||||
code="tts-failed",
|
||||
message="Unexpected error during text-to-speech",
|
||||
) from src_error
|
||||
|
||||
self.tts_stream.async_set_message(tts_input)
|
||||
|
||||
tts_output = {
|
||||
"media_id": tts_media_id,
|
||||
"media_id": self.tts_stream.media_source_id,
|
||||
"token": self.tts_stream.token,
|
||||
"url": self.tts_stream.url,
|
||||
"mime_type": self.tts_stream.content_type,
|
||||
|
@ -25,6 +25,9 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import ffmpeg, websocket_api
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.media_source import (
|
||||
generate_media_source_id as ms_generate_media_source_id,
|
||||
)
|
||||
from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, PLATFORM_FORMAT
|
||||
from homeassistant.core import (
|
||||
@ -58,6 +61,7 @@ from .const import (
|
||||
DEFAULT_CACHE_DIR,
|
||||
DEFAULT_TIME_MEMORY,
|
||||
DOMAIN,
|
||||
MEDIA_SOURCE_STREAM_PATH,
|
||||
TtsAudioType,
|
||||
)
|
||||
from .entity import TextToSpeechEntity, TTSAudioRequest, TTSAudioResponse
|
||||
@ -273,9 +277,17 @@ async def async_get_media_source_audio(
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
"""Get TTS audio as extension, data."""
|
||||
manager = hass.data[DATA_TTS_MANAGER]
|
||||
parsed = parse_media_source_id(media_source_id)
|
||||
stream = hass.data[DATA_TTS_MANAGER].async_create_result_stream(**parsed["options"])
|
||||
stream.async_set_message(parsed["message"])
|
||||
if "stream" in parsed:
|
||||
stream = manager.async_get_result_stream(
|
||||
parsed["stream"] # type: ignore[typeddict-item]
|
||||
)
|
||||
if stream is None:
|
||||
raise ValueError("Stream not found")
|
||||
else:
|
||||
stream = manager.async_create_result_stream(**parsed["options"])
|
||||
stream.async_set_message(parsed["message"])
|
||||
data = b"".join([chunk async for chunk in stream.async_stream_result()])
|
||||
return stream.extension, data
|
||||
|
||||
@ -478,6 +490,14 @@ class ResultStream:
|
||||
"""Get the URL to stream the result."""
|
||||
return f"/api/tts_proxy/{self.token}"
|
||||
|
||||
@cached_property
|
||||
def media_source_id(self) -> str:
|
||||
"""Get the media source ID of this stream."""
|
||||
return ms_generate_media_source_id(
|
||||
DOMAIN,
|
||||
f"{MEDIA_SOURCE_STREAM_PATH}/{self.token}",
|
||||
)
|
||||
|
||||
@cached_property
|
||||
def _result_cache(self) -> asyncio.Future[TTSCache]:
|
||||
"""Get the future that returns the cache."""
|
||||
|
@ -30,4 +30,6 @@ DATA_COMPONENT: HassKey[EntityComponent[TextToSpeechEntity]] = HassKey(DOMAIN)
|
||||
|
||||
DATA_TTS_MANAGER: HassKey[SpeechManager] = HassKey("tts_manager")
|
||||
|
||||
MEDIA_SOURCE_STREAM_PATH = "-stream-"
|
||||
|
||||
type TtsAudioType = tuple[str | None, bytes | None]
|
||||
|
@ -19,7 +19,7 @@ from homeassistant.components.media_source import (
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN
|
||||
from .const import DATA_COMPONENT, DATA_TTS_MANAGER, DOMAIN, MEDIA_SOURCE_STREAM_PATH
|
||||
from .helper import get_engine_instance
|
||||
|
||||
URL_QUERY_TTS_OPTIONS = "tts_options"
|
||||
@ -81,10 +81,22 @@ class ParsedMediaSourceId(TypedDict):
|
||||
message: str
|
||||
|
||||
|
||||
class ParsedMediaSourceStreamId(TypedDict):
|
||||
"""Parsed media source ID for a stream."""
|
||||
|
||||
stream: str
|
||||
|
||||
|
||||
@callback
|
||||
def parse_media_source_id(media_source_id: str) -> ParsedMediaSourceId:
|
||||
def parse_media_source_id(
|
||||
media_source_id: str,
|
||||
) -> ParsedMediaSourceId | ParsedMediaSourceStreamId:
|
||||
"""Turn a media source ID into options."""
|
||||
parsed = URL(media_source_id)
|
||||
|
||||
if parsed.path.startswith(f"{MEDIA_SOURCE_STREAM_PATH}/"):
|
||||
return {"stream": parsed.path[len(MEDIA_SOURCE_STREAM_PATH) + 1 :]}
|
||||
|
||||
if URL_QUERY_TTS_OPTIONS in parsed.query:
|
||||
try:
|
||||
options = json.loads(parsed.query[URL_QUERY_TTS_OPTIONS])
|
||||
@ -122,17 +134,24 @@ class TTSMediaSource(MediaSource):
|
||||
|
||||
async def async_resolve_media(self, item: MediaSourceItem) -> PlayMedia:
|
||||
"""Resolve media to a url."""
|
||||
manager = self.hass.data[DATA_TTS_MANAGER]
|
||||
try:
|
||||
parsed = parse_media_source_id(item.identifier)
|
||||
stream = self.hass.data[DATA_TTS_MANAGER].async_create_result_stream(
|
||||
**parsed["options"]
|
||||
)
|
||||
stream.async_set_message(parsed["message"])
|
||||
if "stream" in parsed:
|
||||
stream = manager.async_get_result_stream(
|
||||
parsed["stream"], # type: ignore[typeddict-item]
|
||||
)
|
||||
else:
|
||||
stream = manager.async_create_result_stream(**parsed["options"])
|
||||
stream.async_set_message(parsed["message"])
|
||||
except Unresolvable:
|
||||
raise
|
||||
except HomeAssistantError as err:
|
||||
raise Unresolvable(str(err)) from err
|
||||
|
||||
if stream is None:
|
||||
raise Unresolvable("Stream not found")
|
||||
|
||||
return PlayMedia(stream.url, stream.content_type)
|
||||
|
||||
async def async_browse_media(
|
||||
|
@ -84,7 +84,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -183,7 +183,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -282,7 +282,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -405,7 +405,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
|
@ -139,7 +139,7 @@
|
||||
dict({
|
||||
'data': dict({
|
||||
'tts_output': dict({
|
||||
'media_id': 'media-source://tts/tts.test?message=hello,+how+are+you?&language=en_US&tts_options=%7B%7D',
|
||||
'media_id': 'media-source://tts/-stream-/mocked-token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
|
@ -80,7 +80,7 @@
|
||||
# name: test_audio_pipeline.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -171,7 +171,7 @@
|
||||
# name: test_audio_pipeline_debug.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -274,7 +274,7 @@
|
||||
# name: test_audio_pipeline_with_enhancements.6
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
@ -387,7 +387,7 @@
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.8
|
||||
dict({
|
||||
'tts_output': dict({
|
||||
'media_id': "media-source://tts/tts.test?message=Sorry,+I+couldn't+understand+that&language=en_US&tts_options=%7B%7D",
|
||||
'media_id': 'media-source://tts/-stream-/test_token.mp3',
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'test_token.mp3',
|
||||
'url': '/api/tts_proxy/test_token.mp3',
|
||||
|
@ -17,6 +17,7 @@ from homeassistant.setup import async_setup_component
|
||||
|
||||
from .common import (
|
||||
DEFAULT_LANG,
|
||||
MockResultStream,
|
||||
MockTTSEntity,
|
||||
MockTTSProvider,
|
||||
mock_config_entry_setup,
|
||||
@ -198,6 +199,17 @@ async def test_resolving(
|
||||
assert language == "de_DE"
|
||||
assert mock_get_tts_audio.mock_calls[0][2]["options"] == {"voice": "Paulus"}
|
||||
|
||||
# Test with result stream
|
||||
stream = MockResultStream(hass, "wav", b"")
|
||||
media = await media_source.async_resolve_media(hass, stream.media_source_id, None)
|
||||
assert media.url == stream.url
|
||||
assert media.mime_type == stream.content_type
|
||||
|
||||
with pytest.raises(media_source.Unresolvable):
|
||||
await media_source.async_resolve_media(
|
||||
hass, "media-source://tts/-stream-/not-a-valid-token", None
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("mock_provider", "mock_tts_entity"),
|
||||
|
Loading…
x
Reference in New Issue
Block a user