From 30ab068bfe21e98b5799bac1c00043059ec93dd4 Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen
Date: Sat, 19 Apr 2025 06:50:41 -0400
Subject: [PATCH] Wyoming to use tokens instead of media source IDs for TTS
 (#139668)

Co-authored-by: Franck Nijhof
---
 .../components/wyoming/assist_satellite.py | 72 +++++++++++++------
 tests/components/wyoming/test_satellite.py | 21 ++----
 2 files changed, 56 insertions(+), 37 deletions(-)

diff --git a/homeassistant/components/wyoming/assist_satellite.py b/homeassistant/components/wyoming/assist_satellite.py
index 5440b2bebeb..88939f0ba77 100644
--- a/homeassistant/components/wyoming/assist_satellite.py
+++ b/homeassistant/components/wyoming/assist_satellite.py
@@ -178,7 +178,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
             self._pipeline_ended_event.set()
             self.device.set_is_active(False)
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
-            self.hass.add_job(self._client.write_event(Detect().event()))
+            self.config_entry.async_create_background_task(
+                self.hass,
+                self._client.write_event(Detect().event()),
+                f"{self.entity_id} {event.type}",
+            )
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
             # Wake word detection
             # Inform client of wake word detection
@@ -187,46 +191,59 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                     name=wake_word_output["wake_word_id"],
                     timestamp=wake_word_output.get("timestamp"),
                 )
-                self.hass.add_job(self._client.write_event(detection.event()))
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(detection.event()),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.STT_START:
             # Speech-to-text
             self.device.set_is_active(True)
 
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Transcribe(language=event.data["metadata"]["language"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
             # User started speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStarted(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
             # User stopped speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStopped(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_END:
             # Speech-to-text transcript
             if event.data:
                 # Inform client of transript
                 stt_text = event.data["stt_output"]["text"]
-                self.hass.add_job(
-                    self._client.write_event(Transcript(text=stt_text).event())
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(Transcript(text=stt_text).event()),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_START:
             # Text-to-speech text
             if event.data:
                 # Inform client of text
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Synthesize(
                             text=event.data["tts_input"],
@@ -235,22 +252,32 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                                 language=event.data.get("language"),
                             ),
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_END:
             # TTS stream
-            if event.data and (tts_output := event.data["tts_output"]):
-                media_id = tts_output["media_id"]
-                self.hass.add_job(self._stream_tts(media_id))
+            if (
+                event.data
+                and (tts_output := event.data["tts_output"])
+                and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
+            ):
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._stream_tts(stream),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.ERROR:
             # Pipeline error
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Error(
                             text=event.data["message"], code=event.data["code"]
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
 
     async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
@@ -662,13 +689,16 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
         await self._client.disconnect()
         self._client = None
 
-    async def _stream_tts(self, media_id: str) -> None:
+    async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
         """Stream TTS WAV audio to satellite in chunks."""
         assert self._client is not None
 
-        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
-        if extension != "wav":
-            raise ValueError(f"Cannot stream audio format to satellite: {extension}")
+        if tts_result.extension != "wav":
+            raise ValueError(
+                f"Cannot stream audio format to satellite: {tts_result.extension}"
+            )
+
+        data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
 
         with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
             sample_rate = wav_file.getframerate()
diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py
index 0e4bb3da78c..800870f4604 100644
--- a/tests/components/wyoming/test_satellite.py
+++ b/tests/components/wyoming/test_satellite.py
@@ -35,6 +35,7 @@ from homeassistant.setup import async_setup_component
 from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
 
 from tests.common import MockConfigEntry
+from tests.components.tts.common import MockResultStream
 
 
 async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
@@ -259,10 +260,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             async_pipeline_from_audio_stream,
         ),
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("wav", get_test_wav()),
-        ),
         patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
     ):
         entry = await setup_config_entry(hass)
@@ -411,10 +408,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         assert mock_client.synthesize.voice.name == "test voice"
 
         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
         pipeline_event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
         async with asyncio.timeout(1):
@@ -435,12 +433,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         )
         assert not device.is_active
 
-        # The client should have received another ping by now
-        async with asyncio.timeout(1):
-            await mock_client.ping_event.wait()
-
-        assert mock_client.ping is not None
-
         # Pipeline should automatically restart
         async with asyncio.timeout(1):
             await run_pipeline_called.wait()
@@ -746,10 +738,6 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             wraps=_async_pipeline_from_audio_stream,
         ) as mock_run_pipeline,
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("mp3", bytes(1)),
-        ),
         patch(
             "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
             _stream_tts,
@@ -779,10 +767,11 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
         await mock_client.synthesize_event.wait()
 
         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "mp3", bytes(1))
         event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
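
Illustrative sketch, not part of the patch: it only uses the API already visible in the diff
(tts.async_get_stream, ResultStream.extension, ResultStream.async_stream_result) to show how a
TTS_END consumer resolves the new token into buffered WAV bytes. The helper name handle_tts_end
and its signature are hypothetical.

from homeassistant.components import tts
from homeassistant.core import HomeAssistant


async def handle_tts_end(hass: HomeAssistant, tts_output: dict) -> bytes | None:
    """Resolve a TTS_END token to its result stream and buffer the WAV audio."""
    # The pipeline now hands out a short-lived token instead of a media source ID.
    stream = tts.async_get_stream(hass, tts_output["token"])
    if stream is None:
        # Unknown or expired token: nothing to stream.
        return None
    if stream.extension != "wav":
        # Satellites only accept WAV, mirroring the check in _stream_tts above.
        raise ValueError(f"Cannot stream audio format to satellite: {stream.extension}")
    # Collect the chunks as the TTS engine produces them.
    return b"".join([chunk async for chunk in stream.async_stream_result()])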