mirror of
https://github.com/home-assistant/core.git
synced 2025-07-19 11:17:21 +00:00
Don't return TTS URL in Assist pipeline (#105164)
* Don't return TTS URL * Add test for empty queue
This commit is contained in:
parent
6666b796f2
commit
4c4ad9404f
@ -9,7 +9,7 @@ from dataclasses import asdict, dataclass, field
|
|||||||
from enum import StrEnum
|
from enum import StrEnum
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from queue import Queue
|
from queue import Empty, Queue
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
import time
|
import time
|
||||||
from typing import TYPE_CHECKING, Any, Final, cast
|
from typing import TYPE_CHECKING, Any, Final, cast
|
||||||
@ -1010,8 +1010,8 @@ class PipelineRun:
|
|||||||
self.tts_engine = engine
|
self.tts_engine = engine
|
||||||
self.tts_options = tts_options
|
self.tts_options = tts_options
|
||||||
|
|
||||||
async def text_to_speech(self, tts_input: str) -> str:
|
async def text_to_speech(self, tts_input: str) -> None:
|
||||||
"""Run text-to-speech portion of pipeline. Returns URL of TTS audio."""
|
"""Run text-to-speech portion of pipeline."""
|
||||||
self.process_event(
|
self.process_event(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
PipelineEventType.TTS_START,
|
PipelineEventType.TTS_START,
|
||||||
@ -1058,8 +1058,6 @@ class PipelineRun:
|
|||||||
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
PipelineEvent(PipelineEventType.TTS_END, {"tts_output": tts_output})
|
||||||
)
|
)
|
||||||
|
|
||||||
return tts_media.url
|
|
||||||
|
|
||||||
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
def _capture_chunk(self, audio_bytes: bytes | None) -> None:
|
||||||
"""Forward audio chunk to various capturing mechanisms."""
|
"""Forward audio chunk to various capturing mechanisms."""
|
||||||
if self.debug_recording_queue is not None:
|
if self.debug_recording_queue is not None:
|
||||||
@ -1246,6 +1244,8 @@ def _pipeline_debug_recording_thread_proc(
|
|||||||
# Chunk of 16-bit mono audio at 16Khz
|
# Chunk of 16-bit mono audio at 16Khz
|
||||||
if wav_writer is not None:
|
if wav_writer is not None:
|
||||||
wav_writer.writeframes(message)
|
wav_writer.writeframes(message)
|
||||||
|
except Empty:
|
||||||
|
pass # occurs when pipeline has unexpected error
|
||||||
except Exception: # pylint: disable=broad-exception-caught
|
except Exception: # pylint: disable=broad-exception-caught
|
||||||
_LOGGER.exception("Unexpected error in debug recording thread")
|
_LOGGER.exception("Unexpected error in debug recording thread")
|
||||||
finally:
|
finally:
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
"""Test Voice Assistant init."""
|
"""Test Voice Assistant init."""
|
||||||
|
import asyncio
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -569,6 +570,69 @@ async def test_pipeline_saved_audio_write_error(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_saved_audio_empty_queue(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_stt_provider: MockSttProvider,
|
||||||
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
|
init_supporting_components,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that saved audio thread closes WAV file even if there's an empty queue."""
|
||||||
|
with tempfile.TemporaryDirectory() as temp_dir_str:
|
||||||
|
# Enable audio recording to temporary directory
|
||||||
|
temp_dir = Path(temp_dir_str)
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
DOMAIN,
|
||||||
|
{DOMAIN: {CONF_DEBUG_RECORDING_DIR: temp_dir_str}},
|
||||||
|
)
|
||||||
|
|
||||||
|
def event_callback(event: assist_pipeline.PipelineEvent):
|
||||||
|
if event.type == "run-end":
|
||||||
|
# Verify WAV file exists, but contains no data
|
||||||
|
pipeline_dirs = list(temp_dir.iterdir())
|
||||||
|
run_dirs = list(pipeline_dirs[0].iterdir())
|
||||||
|
wav_path = next(run_dirs[0].iterdir())
|
||||||
|
with wave.open(str(wav_path), "rb") as wav_file:
|
||||||
|
assert wav_file.getnframes() == 0
|
||||||
|
|
||||||
|
async def audio_data():
|
||||||
|
# Force timeout in _pipeline_debug_recording_thread_proc
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
yield b"not used"
|
||||||
|
|
||||||
|
# Wrap original function to time out immediately
|
||||||
|
_pipeline_debug_recording_thread_proc = (
|
||||||
|
assist_pipeline.pipeline._pipeline_debug_recording_thread_proc
|
||||||
|
)
|
||||||
|
|
||||||
|
def proc_wrapper(run_recording_dir, queue):
|
||||||
|
_pipeline_debug_recording_thread_proc(
|
||||||
|
run_recording_dir, queue, message_timeout=0
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline._pipeline_debug_recording_thread_proc",
|
||||||
|
proc_wrapper,
|
||||||
|
):
|
||||||
|
await assist_pipeline.async_pipeline_from_audio_stream(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
event_callback=event_callback,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language="",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=audio_data(),
|
||||||
|
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.STT,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_detection_aborted(
|
async def test_wake_word_detection_aborted(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider: MockSttProvider,
|
mock_stt_provider: MockSttProvider,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user