Hide TTS filename behind random token (#131192)

* Hide TTS filename behind random token

* Clean up and fix test snapshots

* Fix tests

* Fix cloud tests
This commit is contained in:
Michael Hansen 2024-11-24 19:52:21 -06:00 committed by GitHub
parent cb4636ada1
commit d4071e7123
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 694 additions and 657 deletions

View File

@ -13,6 +13,7 @@ import logging
import mimetypes import mimetypes
import os import os
import re import re
import secrets
import subprocess import subprocess
import tempfile import tempfile
from typing import Any, Final, TypedDict, final from typing import Any, Final, TypedDict, final
@ -540,6 +541,10 @@ class SpeechManager:
self.file_cache: dict[str, str] = {} self.file_cache: dict[str, str] = {}
self.mem_cache: dict[str, TTSCache] = {} self.mem_cache: dict[str, TTSCache] = {}
# filename <-> token
self.filename_to_token: dict[str, str] = {}
self.token_to_filename: dict[str, str] = {}
def _init_cache(self) -> dict[str, str]: def _init_cache(self) -> dict[str, str]:
"""Init cache folder and fetch files.""" """Init cache folder and fetch files."""
try: try:
@ -656,7 +661,17 @@ class SpeechManager:
engine_instance, cache_key, message, use_cache, language, options engine_instance, cache_key, message, use_cache, language, options
) )
return f"/api/tts_proxy/{filename}" # Use a randomly generated token instead of exposing the filename
token = self.filename_to_token.get(filename)
if not token:
# Keep extension (.mp3, etc.)
token = secrets.token_urlsafe(16) + os.path.splitext(filename)[1]
# Map token <-> filename
self.filename_to_token[filename] = token
self.token_to_filename[token] = filename
return f"/api/tts_proxy/{token}"
async def async_get_tts_audio( async def async_get_tts_audio(
self, self,
@ -910,11 +925,15 @@ class SpeechManager:
), ),
) )
async def async_read_tts(self, filename: str) -> tuple[str | None, bytes]: async def async_read_tts(self, token: str) -> tuple[str | None, bytes]:
"""Read a voice file and return binary. """Read a voice file and return binary.
This method is a coroutine. This method is a coroutine.
""" """
filename = self.token_to_filename.get(token)
if not filename:
raise HomeAssistantError(f"{token} was not recognized!")
if not (record := _RE_VOICE_FILE.match(filename.lower())) and not ( if not (record := _RE_VOICE_FILE.match(filename.lower())) and not (
record := _RE_LEGACY_VOICE_FILE.match(filename.lower()) record := _RE_LEGACY_VOICE_FILE.match(filename.lower())
): ):
@ -1076,6 +1095,7 @@ class TextToSpeechView(HomeAssistantView):
async def get(self, request: web.Request, filename: str) -> web.Response: async def get(self, request: web.Request, filename: str) -> web.Response:
"""Start a get request.""" """Start a get request."""
try: try:
# filename is actually token, but we keep its name for compatibility
content, data = await self.tts.async_read_tts(filename) content, data = await self.tts.async_read_tts(filename)
except HomeAssistantError as err: except HomeAssistantError as err:
_LOGGER.error("Error on load tts: %s", err) _LOGGER.error("Error on load tts: %s", err)

View File

@ -77,7 +77,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}), }),
'type': <PipelineEventType.TTS_END: 'tts-end'>, 'type': <PipelineEventType.TTS_END: 'tts-end'>,
@ -166,7 +166,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}), }),
'type': <PipelineEventType.TTS_END: 'tts-end'>, 'type': <PipelineEventType.TTS_END: 'tts-end'>,
@ -255,7 +255,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22Arnold+Schwarzenegger%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_2657c1a8ee_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}), }),
'type': <PipelineEventType.TTS_END: 'tts-end'>, 'type': <PipelineEventType.TTS_END: 'tts-end'>,
@ -368,7 +368,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}), }),
'type': <PipelineEventType.TTS_END: 'tts-end'>, 'type': <PipelineEventType.TTS_END: 'tts-end'>,

View File

@ -73,7 +73,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}) })
# --- # ---
@ -154,7 +154,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}) })
# --- # ---
@ -247,7 +247,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}) })
# --- # ---
@ -350,7 +350,7 @@
'tts_output': dict({ 'tts_output': dict({
'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D", 'media_id': "media-source://tts/test?message=Sorry,+I+couldn't+understand+that&language=en-US&tts_options=%7B%22voice%22:%22james_earl_jones%22%7D",
'mime_type': 'audio/mpeg', 'mime_type': 'audio/mpeg',
'url': '/api/tts_proxy/dae2cdcb27a1d1c3b07ba2c7db91480f9d4bfd8f_en-us_031e2ec052_test.mp3', 'url': '/api/tts_proxy/test_token.mp3',
}), }),
}) })
# --- # ---

View File

@ -70,21 +70,24 @@ async def test_pipeline_from_audio_stream_auto(
yield make_10ms_chunk(b"part2") yield make_10ms_chunk(b"part2")
yield b"" yield b""
await assist_pipeline.async_pipeline_from_audio_stream( with patch(
hass, "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
context=Context(), ):
event_callback=events.append, await assist_pipeline.async_pipeline_from_audio_stream(
stt_metadata=stt.SpeechMetadata( hass,
language="", context=Context(),
format=stt.AudioFormats.WAV, event_callback=events.append,
codec=stt.AudioCodecs.PCM, stt_metadata=stt.SpeechMetadata(
bit_rate=stt.AudioBitRates.BITRATE_16, language="",
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, format=stt.AudioFormats.WAV,
channel=stt.AudioChannels.CHANNEL_MONO, codec=stt.AudioCodecs.PCM,
), bit_rate=stt.AudioBitRates.BITRATE_16,
stt_stream=audio_data(), sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), channel=stt.AudioChannels.CHANNEL_MONO,
) ),
stt_stream=audio_data(),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot assert process_events(events) == snapshot
assert len(mock_stt_provider_entity.received) == 2 assert len(mock_stt_provider_entity.received) == 2
@ -133,23 +136,26 @@ async def test_pipeline_from_audio_stream_legacy(
assert msg["success"] assert msg["success"]
pipeline_id = msg["result"]["id"] pipeline_id = msg["result"]["id"]
# Use the created pipeline with patch(
await assist_pipeline.async_pipeline_from_audio_stream( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
hass, ):
context=Context(), # Use the created pipeline
event_callback=events.append, await assist_pipeline.async_pipeline_from_audio_stream(
stt_metadata=stt.SpeechMetadata( hass,
language="en-UK", context=Context(),
format=stt.AudioFormats.WAV, event_callback=events.append,
codec=stt.AudioCodecs.PCM, stt_metadata=stt.SpeechMetadata(
bit_rate=stt.AudioBitRates.BITRATE_16, language="en-UK",
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, format=stt.AudioFormats.WAV,
channel=stt.AudioChannels.CHANNEL_MONO, codec=stt.AudioCodecs.PCM,
), bit_rate=stt.AudioBitRates.BITRATE_16,
stt_stream=audio_data(), sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
pipeline_id=pipeline_id, channel=stt.AudioChannels.CHANNEL_MONO,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ),
) stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot assert process_events(events) == snapshot
assert len(mock_stt_provider.received) == 2 assert len(mock_stt_provider.received) == 2
@ -198,23 +204,26 @@ async def test_pipeline_from_audio_stream_entity(
assert msg["success"] assert msg["success"]
pipeline_id = msg["result"]["id"] pipeline_id = msg["result"]["id"]
# Use the created pipeline with patch(
await assist_pipeline.async_pipeline_from_audio_stream( "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
hass, ):
context=Context(), # Use the created pipeline
event_callback=events.append, await assist_pipeline.async_pipeline_from_audio_stream(
stt_metadata=stt.SpeechMetadata( hass,
language="en-UK", context=Context(),
format=stt.AudioFormats.WAV, event_callback=events.append,
codec=stt.AudioCodecs.PCM, stt_metadata=stt.SpeechMetadata(
bit_rate=stt.AudioBitRates.BITRATE_16, language="en-UK",
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, format=stt.AudioFormats.WAV,
channel=stt.AudioChannels.CHANNEL_MONO, codec=stt.AudioCodecs.PCM,
), bit_rate=stt.AudioBitRates.BITRATE_16,
stt_stream=audio_data(), sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
pipeline_id=pipeline_id, channel=stt.AudioChannels.CHANNEL_MONO,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), ),
) stt_stream=audio_data(),
pipeline_id=pipeline_id,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot assert process_events(events) == snapshot
assert len(mock_stt_provider_entity.received) == 2 assert len(mock_stt_provider_entity.received) == 2
@ -362,25 +371,28 @@ async def test_pipeline_from_audio_stream_wake_word(
yield b"" yield b""
await assist_pipeline.async_pipeline_from_audio_stream( with patch(
hass, "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
context=Context(), ):
event_callback=events.append, await assist_pipeline.async_pipeline_from_audio_stream(
stt_metadata=stt.SpeechMetadata( hass,
language="", context=Context(),
format=stt.AudioFormats.WAV, event_callback=events.append,
codec=stt.AudioCodecs.PCM, stt_metadata=stt.SpeechMetadata(
bit_rate=stt.AudioBitRates.BITRATE_16, language="",
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000, format=stt.AudioFormats.WAV,
channel=stt.AudioChannels.CHANNEL_MONO, codec=stt.AudioCodecs.PCM,
), bit_rate=stt.AudioBitRates.BITRATE_16,
stt_stream=audio_data(), sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
start_stage=assist_pipeline.PipelineStage.WAKE_WORD, channel=stt.AudioChannels.CHANNEL_MONO,
wake_word_settings=assist_pipeline.WakeWordSettings( ),
audio_seconds_to_buffer=1.5 stt_stream=audio_data(),
), start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False), wake_word_settings=assist_pipeline.WakeWordSettings(
) audio_seconds_to_buffer=1.5
),
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
)
assert process_events(events) == snapshot assert process_events(events) == snapshot

View File

@ -119,85 +119,88 @@ async def test_audio_pipeline(
events = [] events = []
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id( with patch(
{ "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"type": "assist_pipeline/run", ):
"start_stage": "stt", await client.send_json_auto_id(
"end_stage": "tts", {
"input": { "type": "assist_pipeline/run",
"sample_rate": 44100, "start_stage": "stt",
}, "end_stage": "tts",
} "input": {
) "sample_rate": 44100,
},
}
)
# result # result
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([handler_id])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end" assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# text-to-speech # text-to-speech
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start" assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end" assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-end" assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_debug)[0] pipeline_id = list(pipeline_data.pipeline_debug)[0]
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0] pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0]
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline_debug/get", "type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id, "pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id, "pipeline_run_id": pipeline_run_id,
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_audio_pipeline_with_wake_word_timeout( async def test_audio_pipeline_with_wake_word_timeout(
@ -210,49 +213,52 @@ async def test_audio_pipeline_with_wake_word_timeout(
events = [] events = []
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id( with patch(
{ "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"type": "assist_pipeline/run", ):
"start_stage": "wake_word", await client.send_json_auto_id(
"end_stage": "tts", {
"input": { "type": "assist_pipeline/run",
"sample_rate": SAMPLE_RATE, "start_stage": "wake_word",
"timeout": 1, "end_stage": "tts",
}, "input": {
} "sample_rate": SAMPLE_RATE,
) "timeout": 1,
},
}
)
# result # result
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"], msg assert msg["success"], msg
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# wake_word # wake_word
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start" assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# 2 seconds of silence # 2 seconds of silence
await client.send_bytes(bytes([1]) + bytes(2 * BYTES_ONE_SECOND)) await client.send_bytes(bytes([1]) + bytes(2 * BYTES_ONE_SECOND))
# Time out error # Time out error
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "error" assert msg["event"]["type"] == "error"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-end" assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
async def test_audio_pipeline_with_wake_word_no_timeout( async def test_audio_pipeline_with_wake_word_no_timeout(
@ -265,98 +271,101 @@ async def test_audio_pipeline_with_wake_word_no_timeout(
events = [] events = []
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id( with patch(
{ "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"type": "assist_pipeline/run", ):
"start_stage": "wake_word", await client.send_json_auto_id(
"end_stage": "tts", {
"input": {"sample_rate": SAMPLE_RATE, "timeout": 0, "no_vad": True}, "type": "assist_pipeline/run",
} "start_stage": "wake_word",
) "end_stage": "tts",
"input": {"sample_rate": SAMPLE_RATE, "timeout": 0, "no_vad": True},
}
)
# result # result
msg = await client.receive_json()
assert msg["success"], msg
# run start
msg = await client.receive_json()
assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# wake_word
msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# "audio"
await client.send_bytes(bytes([handler_id]) + make_10ms_chunk(b"wake word"))
async with asyncio.timeout(1):
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-end" assert msg["success"], msg
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# stt # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "run-start"
assert msg["event"]["data"] == snapshot msg["event"]["data"]["pipeline"] = ANY
events.append(msg["event"]) assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"])
# End of audio stream (handler id + empty payload) # wake_word
await client.send_bytes(bytes([handler_id])) msg = await client.receive_json()
assert msg["event"]["type"] == "wake_word-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json() # "audio"
assert msg["event"]["type"] == "stt-end" await client.send_bytes(bytes([handler_id]) + make_10ms_chunk(b"wake word"))
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# intent async with asyncio.timeout(1):
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "wake_word-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() # stt
assert msg["event"]["type"] == "intent-end" msg = await client.receive_json()
assert msg["event"]["data"] == snapshot assert msg["event"]["type"] == "stt-start"
events.append(msg["event"]) assert msg["event"]["data"] == snapshot
events.append(msg["event"])
# text-to-speech # End of audio stream (handler id + empty payload)
msg = await client.receive_json() await client.send_bytes(bytes([handler_id]))
assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end" assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# run end # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-end" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN] msg = await client.receive_json()
pipeline_id = list(pipeline_data.pipeline_debug)[0] assert msg["event"]["type"] == "intent-end"
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0] assert msg["event"]["data"] == snapshot
events.append(msg["event"])
await client.send_json_auto_id( # text-to-speech
{ msg = await client.receive_json()
"type": "assist_pipeline/pipeline_debug/get", assert msg["event"]["type"] == "tts-start"
"pipeline_id": pipeline_id, assert msg["event"]["data"] == snapshot
"pipeline_run_id": pipeline_run_id, events.append(msg["event"])
}
) msg = await client.receive_json()
msg = await client.receive_json() assert msg["event"]["type"] == "tts-end"
assert msg["success"] assert msg["event"]["data"] == snapshot
assert msg["result"] == {"events": events} events.append(msg["event"])
# run end
msg = await client.receive_json()
assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot
events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_debug)[0]
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0]
await client.send_json_auto_id(
{
"type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id,
}
)
msg = await client.receive_json()
assert msg["success"]
assert msg["result"] == {"events": events}
async def test_audio_pipeline_no_wake_word_engine( async def test_audio_pipeline_no_wake_word_engine(
@ -1540,99 +1549,102 @@ async def test_audio_pipeline_debug(
events = [] events = []
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id( with patch(
{ "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"type": "assist_pipeline/run", ):
"start_stage": "stt", await client.send_json_auto_id(
"end_stage": "tts", {
"input": { "type": "assist_pipeline/run",
"sample_rate": 44100, "start_stage": "stt",
}, "end_stage": "tts",
} "input": {
) "sample_rate": 44100,
},
}
)
# result # result
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([handler_id])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end" assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# text-to-speech # text-to-speech
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start" assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end" assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-end" assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# Get the id of the pipeline # Get the id of the pipeline
await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"}) await client.send_json_auto_id({"type": "assist_pipeline/pipeline/list"})
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert len(msg["result"]["pipelines"]) == 1 assert len(msg["result"]["pipelines"]) == 1
pipeline_id = msg["result"]["pipelines"][0]["id"] pipeline_id = msg["result"]["pipelines"][0]["id"]
# Get the id for the run # Get the id for the run
await client.send_json_auto_id( await client.send_json_auto_id(
{"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id} {"type": "assist_pipeline/pipeline_debug/list", "pipeline_id": pipeline_id}
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"pipeline_runs": [ANY]} assert msg["result"] == {"pipeline_runs": [ANY]}
pipeline_run_id = msg["result"]["pipeline_runs"][0]["pipeline_run_id"] pipeline_run_id = msg["result"]["pipeline_runs"][0]["pipeline_run_id"]
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline_debug/get", "type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id, "pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id, "pipeline_run_id": pipeline_run_id,
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_pipeline_debug_list_runs_wrong_pipeline( async def test_pipeline_debug_list_runs_wrong_pipeline(
@ -1787,94 +1799,97 @@ async def test_audio_pipeline_with_enhancements(
events = [] events = []
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
await client.send_json_auto_id( with patch(
{ "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"type": "assist_pipeline/run", ):
"start_stage": "stt", await client.send_json_auto_id(
"end_stage": "tts", {
"input": { "type": "assist_pipeline/run",
"sample_rate": SAMPLE_RATE, "start_stage": "stt",
# Enhancements "end_stage": "tts",
"noise_suppression_level": 2, "input": {
"auto_gain_dbfs": 15, "sample_rate": SAMPLE_RATE,
"volume_multiplier": 2.0, # Enhancements
}, "noise_suppression_level": 2,
} "auto_gain_dbfs": 15,
) "volume_multiplier": 2.0,
},
}
)
# result # result
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
# run start # run start
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-start" assert msg["event"]["type"] == "run-start"
msg["event"]["data"]["pipeline"] = ANY msg["event"]["data"]["pipeline"] = ANY
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"] handler_id = msg["event"]["data"]["runner_data"]["stt_binary_handler_id"]
events.append(msg["event"]) events.append(msg["event"])
# stt # stt
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-start" assert msg["event"]["type"] == "stt-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# One second of silence. # One second of silence.
# This will pass through the audio enhancement pipeline, but we don't test # This will pass through the audio enhancement pipeline, but we don't test
# the actual output. # the actual output.
await client.send_bytes(bytes([handler_id]) + bytes(BYTES_ONE_SECOND)) await client.send_bytes(bytes([handler_id]) + bytes(BYTES_ONE_SECOND))
# End of audio stream (handler id + empty payload) # End of audio stream (handler id + empty payload)
await client.send_bytes(bytes([handler_id])) await client.send_bytes(bytes([handler_id]))
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "stt-end" assert msg["event"]["type"] == "stt-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# intent # intent
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-start" assert msg["event"]["type"] == "intent-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "intent-end" assert msg["event"]["type"] == "intent-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# text-to-speech # text-to-speech
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-start" assert msg["event"]["type"] == "tts-start"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "tts-end" assert msg["event"]["type"] == "tts-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
# run end # run end
msg = await client.receive_json() msg = await client.receive_json()
assert msg["event"]["type"] == "run-end" assert msg["event"]["type"] == "run-end"
assert msg["event"]["data"] == snapshot assert msg["event"]["data"] == snapshot
events.append(msg["event"]) events.append(msg["event"])
pipeline_data: PipelineData = hass.data[DOMAIN] pipeline_data: PipelineData = hass.data[DOMAIN]
pipeline_id = list(pipeline_data.pipeline_debug)[0] pipeline_id = list(pipeline_data.pipeline_debug)[0]
pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0] pipeline_run_id = list(pipeline_data.pipeline_debug[pipeline_id])[0]
await client.send_json_auto_id( await client.send_json_auto_id(
{ {
"type": "assist_pipeline/pipeline_debug/get", "type": "assist_pipeline/pipeline_debug/get",
"pipeline_id": pipeline_id, "pipeline_id": pipeline_id,
"pipeline_run_id": pipeline_run_id, "pipeline_run_id": pipeline_run_id,
} }
) )
msg = await client.receive_json() msg = await client.receive_json()
assert msg["success"] assert msg["success"]
assert msg["result"] == {"events": events} assert msg["result"] == {"events": events}
async def test_wake_word_cooldown_same_id( async def test_wake_word_cooldown_same_id(

View File

@ -227,25 +227,21 @@ async def test_get_tts_audio(
await on_start_callback() await on_start_callback()
client = await hass_client() client = await hass_client()
url = "/api/tts_get_url" with patch(
data |= {"message": "There is someone at the door."} "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
):
url = "/api/tts_get_url"
data |= {"message": "There is someone at the door."}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1 assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None assert mock_process_tts.call_args is not None
@ -280,25 +276,21 @@ async def test_get_tts_audio_logged_out(
await hass.async_block_till_done() await hass.async_block_till_done()
client = await hass_client() client = await hass_client()
url = "/api/tts_get_url" with patch(
data |= {"message": "There is someone at the door."} "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
):
url = "/api/tts_get_url"
data |= {"message": "There is someone at the door."}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1 assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None assert mock_process_tts.call_args is not None
@ -342,28 +334,24 @@ async def test_tts_entity(
assert state assert state
assert state.state == STATE_UNKNOWN assert state.state == STATE_UNKNOWN
url = "/api/tts_get_url" with patch(
data = { "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"engine_id": entity_id, ):
"message": "There is someone at the door.", url = "/api/tts_get_url"
} data = {
"engine_id": entity_id,
"message": "There is someone at the door.",
}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_en-us_6e8b81ac47_{entity_id}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_6e8b81ac47_{entity_id}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1 assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None assert mock_process_tts.call_args is not None
@ -482,29 +470,25 @@ async def test_deprecated_voice(
client = await hass_client() client = await hass_client()
# Test with non deprecated voice. # Test with non deprecated voice.
url = "/api/tts_get_url" with patch(
data |= { "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"message": "There is someone at the door.", ):
"language": language, url = "/api/tts_get_url"
"options": {"voice": replacement_voice}, data |= {
} "message": "There is someone at the door.",
"language": language,
"options": {"voice": replacement_voice},
}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_{language.lower()}_87567e3e29_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_87567e3e29_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1 assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None assert mock_process_tts.call_args is not None
@ -522,22 +506,18 @@ async def test_deprecated_voice(
# Test with deprecated voice. # Test with deprecated voice.
data["options"] = {"voice": deprecated_voice} data["options"] = {"voice": deprecated_voice}
req = await client.post(url, json=data) with patch(
assert req.status == HTTPStatus.OK "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
response = await req.json() ):
req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK
response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_{language.lower()}_13646b7d32_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_13646b7d32_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
issue_id = f"deprecated_voice_{deprecated_voice}" issue_id = f"deprecated_voice_{deprecated_voice}"
@ -631,28 +611,24 @@ async def test_deprecated_gender(
client = await hass_client() client = await hass_client()
# Test without deprecated gender option. # Test without deprecated gender option.
url = "/api/tts_get_url" with patch(
data |= { "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
"message": "There is someone at the door.", ):
"language": language, url = "/api/tts_get_url"
} data |= {
"message": "There is someone at the door.",
"language": language,
}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_{language.lower()}_6e8b81ac47_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_6e8b81ac47_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
assert mock_process_tts.call_count == 1 assert mock_process_tts.call_count == 1
assert mock_process_tts.call_args is not None assert mock_process_tts.call_args is not None
@ -667,22 +643,18 @@ async def test_deprecated_gender(
# Test with deprecated gender option. # Test with deprecated gender option.
data["options"] = {"gender": gender_option} data["options"] = {"gender": gender_option}
req = await client.post(url, json=data) with patch(
assert req.status == HTTPStatus.OK "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
response = await req.json() ):
req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK
response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_{language.lower()}_dd0e95eb04_{expected_url_suffix}.mp3" await hass.async_block_till_done()
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_{language.lower()}_dd0e95eb04_{expected_url_suffix}.mp3"
),
}
await hass.async_block_till_done()
issue_id = "deprecated_gender" issue_id = "deprecated_gender"

View File

@ -204,18 +204,20 @@ async def test_service(
blocking=True, blocking=True,
) )
assert len(calls) == 1 with patch(
assert calls[0].data[ATTR_MEDIA_ANNOUNCE] is True "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC ):
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == ( assert len(calls) == 1
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" assert calls[0].data[ATTR_MEDIA_ANNOUNCE] is True
f"_en-us_-_{expected_url_suffix}.mp3" assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
) assert await get_media_source_url(
await hass.async_block_till_done() hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
assert ( ) == ("/api/tts_proxy/test_token.mp3")
mock_tts_cache_dir await hass.async_block_till_done()
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3" assert (
).is_file() mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -266,17 +268,20 @@ async def test_service_default_language(
) )
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_de-de_-_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ ( await hass.async_block_till_done()
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3" assert (
) mock_tts_cache_dir
).is_file() / (
f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
)
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -327,15 +332,18 @@ async def test_service_default_special_language(
) )
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_en-us_-_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3" await hass.async_block_till_done()
).is_file() assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_{expected_url_suffix}.mp3"
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -384,15 +392,18 @@ async def test_service_language(
) )
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_de-de_-_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3" await hass.async_block_till_done()
).is_file() assert (
mock_tts_cache_dir
/ f"42f18378fd4393d18c8dd11d03fa9563c1e54491_de-de_-_{expected_url_suffix}.mp3"
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -497,18 +508,21 @@ async def test_service_options(
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ ( await hass.async_block_till_done()
"42f18378fd4393d18c8dd11d03fa9563c1e54491" assert (
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" mock_tts_cache_dir
) / (
).is_file() "42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3"
)
).is_file()
class MockProviderWithDefaults(MockTTSProvider): class MockProviderWithDefaults(MockTTSProvider):
@ -578,18 +592,21 @@ async def test_service_default_options(
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ ( await hass.async_block_till_done()
"42f18378fd4393d18c8dd11d03fa9563c1e54491" assert (
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" mock_tts_cache_dir
) / (
).is_file() "42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3"
)
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -649,18 +666,21 @@ async def test_merge_default_service_options(
assert len(calls) == 1 assert len(calls) == 1
assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC assert calls[0].data[ATTR_MEDIA_CONTENT_TYPE] == MediaType.MUSIC
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491" with patch(
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
assert ( hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
mock_tts_cache_dir ) == ("/api/tts_proxy/test_token.mp3")
/ ( await hass.async_block_till_done()
"42f18378fd4393d18c8dd11d03fa9563c1e54491" assert (
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3" mock_tts_cache_dir
) / (
).is_file() "42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_de-de_{opt_hash}_{expected_url_suffix}.mp3"
)
).is_file()
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -1065,10 +1085,14 @@ async def test_setup_legacy_cache_dir(
) )
assert len(calls) == 1 assert len(calls) == 1
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_test.mp3" with patch(
) "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
await hass.async_block_till_done() ):
assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
@pytest.mark.parametrize("mock_tts_entity", [MockEntityBoom(DEFAULT_LANG)]) @pytest.mark.parametrize("mock_tts_entity", [MockEntityBoom(DEFAULT_LANG)])
@ -1100,10 +1124,13 @@ async def test_setup_cache_dir(
) )
assert len(calls) == 1 assert len(calls) == 1
assert await get_media_source_url(hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]) == ( with patch(
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
) ):
await hass.async_block_till_done() assert await get_media_source_url(
hass, calls[0].data[ATTR_MEDIA_CONTENT_ID]
) == ("/api/tts_proxy/test_token.mp3")
await hass.async_block_till_done()
class MockProviderEmpty(MockTTSProvider): class MockProviderEmpty(MockTTSProvider):
@ -1176,13 +1203,13 @@ async def test_service_get_tts_error(
) )
async def test_load_cache_legacy_retrieve_without_mem_cache( async def test_legacy_cannot_retrieve_without_token(
hass: HomeAssistant, hass: HomeAssistant,
mock_provider: MockTTSProvider, mock_provider: MockTTSProvider,
mock_tts_cache_dir: Path, mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Set up component and load cache and get without mem cache.""" """Verify that a TTS cannot be retrieved by filename directly."""
tts_data = b"" tts_data = b""
cache_file = ( cache_file = (
mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" mock_tts_cache_dir / "42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
@ -1196,17 +1223,16 @@ async def test_load_cache_legacy_retrieve_without_mem_cache(
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3" url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en_-_test.mp3"
req = await client.get(url) req = await client.get(url)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.NOT_FOUND
assert await req.read() == tts_data
async def test_load_cache_retrieve_without_mem_cache( async def test_cannot_retrieve_without_token(
hass: HomeAssistant, hass: HomeAssistant,
mock_tts_entity: MockTTSEntity, mock_tts_entity: MockTTSEntity,
mock_tts_cache_dir: Path, mock_tts_cache_dir: Path,
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator,
) -> None: ) -> None:
"""Set up component and load cache and get without mem cache.""" """Verify that a TTS cannot be retrieved by filename directly."""
tts_data = b"" tts_data = b""
cache_file = mock_tts_cache_dir / ( cache_file = mock_tts_cache_dir / (
"42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" "42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
@ -1220,45 +1246,37 @@ async def test_load_cache_retrieve_without_mem_cache(
url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3" url = "/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491_en-us_-_tts.test.mp3"
req = await client.get(url) req = await client.get(url)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.NOT_FOUND
assert await req.read() == tts_data
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "data", "expected_url_suffix"), ("setup", "data"),
[ [
("mock_setup", {"platform": "test"}, "test"), ("mock_setup", {"platform": "test"}),
("mock_setup", {"engine_id": "test"}, "test"), ("mock_setup", {"engine_id": "test"}),
("mock_config_entry_setup", {"engine_id": "tts.test"}, "tts.test"), ("mock_config_entry_setup", {"engine_id": "tts.test"}),
], ],
indirect=["setup"], indirect=["setup"],
) )
async def test_web_get_url( async def test_web_get_url(
hass_client: ClientSessionGenerator, hass_client: ClientSessionGenerator, setup: str, data: dict[str, Any]
setup: str,
data: dict[str, Any],
expected_url_suffix: str,
) -> None: ) -> None:
"""Set up a TTS platform and receive file from web.""" """Set up a TTS platform and receive file from web."""
client = await hass_client() client = await hass_client()
url = "/api/tts_get_url" with patch(
data |= {"message": "There is someone at the door."} "homeassistant.components.tts.secrets.token_urlsafe", return_value="test_token"
):
url = "/api/tts_get_url"
data |= {"message": "There is someone at the door."}
req = await client.post(url, json=data) req = await client.post(url, json=data)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
response = await req.json() response = await req.json()
assert response == { assert response == {
"url": ( "url": ("http://example.local:8123/api/tts_proxy/test_token.mp3"),
"http://example.local:8123/api/tts_proxy/" "path": ("/api/tts_proxy/test_token.mp3"),
"42f18378fd4393d18c8dd11d03fa9563c1e54491" }
f"_en-us_-_{expected_url_suffix}.mp3"
),
"path": (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_-_{expected_url_suffix}.mp3"
),
}
@pytest.mark.parametrize( @pytest.mark.parametrize(