Wyoming to use tokens instead of media source IDs for TTS (#139668)

Co-authored-by: Franck Nijhof <git@frenck.dev>
Paulus Schoutsen · 2025-04-19 06:50:41 -04:00 · committed by GitHub
parent 6499ad6cdb
commit 30ab068bfe
2 changed files with 56 additions and 37 deletions

homeassistant/components/wyoming/assist_satellite.py

@@ -178,7 +178,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
             self._pipeline_ended_event.set()
             self.device.set_is_active(False)
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
-            self.hass.add_job(self._client.write_event(Detect().event()))
+            self.config_entry.async_create_background_task(
+                self.hass,
+                self._client.write_event(Detect().event()),
+                f"{self.entity_id} {event.type}",
+            )
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
             # Wake word detection
             # Inform client of wake word detection
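The recurring change in this file replaces fire-and-forget self.hass.add_job(...) calls with named background tasks owned by the config entry, so the scheduled work is tracked and cancelled when the entry unloads. A minimal sketch of the pattern, using only calls that appear in the diff (_notify_wake_word is a made-up name for illustration):

    # Illustrative sketch only: schedule the Wyoming event write as a named
    # background task tied to the config entry instead of hass.add_job().
    from wyoming.wake import Detect

    def _notify_wake_word(self) -> None:
        self.config_entry.async_create_background_task(
            self.hass,                                   # HomeAssistant instance
            self._client.write_event(Detect().event()),  # coroutine to run
            f"{self.entity_id} wake_word-start",         # readable task name
        )

The same pattern is applied to every pipeline event below (STT, VAD, transcript, TTS and error events).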
@@ -187,46 +191,59 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                     name=wake_word_output["wake_word_id"],
                     timestamp=wake_word_output.get("timestamp"),
                 )
-                self.hass.add_job(self._client.write_event(detection.event()))
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(detection.event()),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.STT_START:
             # Speech-to-text
             self.device.set_is_active(True)
 
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Transcribe(language=event.data["metadata"]["language"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
             # User started speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStarted(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
             # User stopped speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStopped(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_END:
             # Speech-to-text transcript
             if event.data:
                 # Inform client of transript
                 stt_text = event.data["stt_output"]["text"]
-                self.hass.add_job(
-                    self._client.write_event(Transcript(text=stt_text).event())
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(Transcript(text=stt_text).event()),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_START:
             # Text-to-speech text
             if event.data:
                 # Inform client of text
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Synthesize(
                             text=event.data["tts_input"],
@@ -235,22 +252,32 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                             language=event.data.get("language"),
                         ),
                     ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_END:
             # TTS stream
-            if event.data and (tts_output := event.data["tts_output"]):
-                media_id = tts_output["media_id"]
-                self.hass.add_job(self._stream_tts(media_id))
+            if (
+                event.data
+                and (tts_output := event.data["tts_output"])
+                and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
+            ):
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._stream_tts(stream),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.ERROR:
             # Pipeline error
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Error(
                             text=event.data["message"], code=event.data["code"]
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
 
     async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
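With this hunk, TTS_END no longer carries a media_id; it carries a token, and the stream is resolved in-process with tts.async_get_stream(), which returns None when no stream is registered for that token, in which case nothing is sent to the satellite. A condensed sketch of that lookup under the same assumptions as the diff (_handle_tts_end is a made-up helper name):

    from homeassistant.components import tts

    def _handle_tts_end(self, event) -> None:
        # Illustrative sketch only: resolve the pipeline token into a ResultStream.
        tts_output = (event.data or {}).get("tts_output") or {}
        stream = tts.async_get_stream(self.hass, tts_output.get("token", ""))
        if stream is not None:
            self.config_entry.async_create_background_task(
                self.hass,
                self._stream_tts(stream),
                f"{self.entity_id} {event.type}",
            )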
@@ -662,13 +689,16 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
         await self._client.disconnect()
         self._client = None
 
-    async def _stream_tts(self, media_id: str) -> None:
+    async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
         """Stream TTS WAV audio to satellite in chunks."""
         assert self._client is not None
 
-        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
-        if extension != "wav":
-            raise ValueError(f"Cannot stream audio format to satellite: {extension}")
+        if tts_result.extension != "wav":
+            raise ValueError(
+                f"Cannot stream audio format to satellite: {tts_result.extension}"
+            )
+
+        data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
 
         with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
             sample_rate = wav_file.getframerate()
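_stream_tts now receives a tts.ResultStream instead of a media source ID: the stream exposes the audio format via .extension and the encoded bytes via its async_stream_result() async generator. A hedged sketch of consuming such a stream with the standard library, along the lines of the method above, before the audio is re-chunked over the Wyoming protocol (_read_wav is a made-up helper for illustration):

    import io
    import wave

    async def _read_wav(tts_result) -> tuple[int, int, int, bytes]:
        # Illustrative sketch only: buffer the stream and parse it as a WAV file.
        if tts_result.extension != "wav":
            raise ValueError(f"Cannot stream audio format: {tts_result.extension}")
        data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
        with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
            # Rate, sample width and channel count describe the PCM that follows.
            return (
                wav_file.getframerate(),
                wav_file.getsampwidth(),
                wav_file.getnchannels(),
                wav_file.readframes(wav_file.getnframes()),
            )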

tests/components/wyoming/test_satellite.py

@@ -35,6 +35,7 @@ from homeassistant.setup import async_setup_component
 from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
 
 from tests.common import MockConfigEntry
+from tests.components.tts.common import MockResultStream
 
 
 async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
@@ -259,10 +260,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             async_pipeline_from_audio_stream,
         ),
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("wav", get_test_wav()),
-        ),
         patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
     ):
         entry = await setup_config_entry(hass)
@@ -411,10 +408,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         assert mock_client.synthesize.voice.name == "test voice"
 
         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
         pipeline_event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
         async with asyncio.timeout(1):
@@ -435,12 +433,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         )
         assert not device.is_active
 
-        # The client should have received another ping by now
-        async with asyncio.timeout(1):
-            await mock_client.ping_event.wait()
-
-        assert mock_client.ping is not None
-
         # Pipeline should automatically restart
         async with asyncio.timeout(1):
             await run_pipeline_called.wait()
@@ -746,10 +738,6 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             wraps=_async_pipeline_from_audio_stream,
         ) as mock_run_pipeline,
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("mp3", bytes(1)),
-        ),
         patch(
             "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
             _stream_tts,
@@ -779,10 +767,11 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             await mock_client.synthesize_event.wait()
 
         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "mp3", bytes(1))
         event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
            )
        )
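On the test side, MockResultStream (from tests.components.tts.common) stands in for a real TTS result: it is constructed with an extension and audio bytes and exposes the token that goes into the TTS_END payload, which is why the async_get_media_source_audio patches above could be dropped. A hypothetical extra assertion, assuming the mock registers its token with the TTS component (the pipeline test above relies on this behavior):

    from homeassistant.components import tts

    # Resolve the mock's token the same way the Wyoming satellite does.
    stream = tts.async_get_stream(hass, mock_tts_result_stream.token)
    assert stream is not None
    assert stream.extension == "mp3"  # test_tts_not_wav uses an mp3 mock stream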