Wyoming to use tokens instead of media source IDs for TTS (#139668)

Co-authored-by: Franck Nijhof <git@frenck.dev>
Paulus Schoutsen, 2025-04-19 06:50:41 -04:00, committed by GitHub
parent 6499ad6cdb
commit 30ab068bfe
2 changed files with 56 additions and 37 deletions
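In short: the assist pipeline's TTS_END event no longer hands the Wyoming satellite a media source ID to fetch; it now carries a short-lived token that resolves to an in-memory TTS result stream. A rough sketch of the payload difference, with keys taken from the diff below and values purely illustrative:

# Before: the satellite fetched the audio through the media source layer.
old_tts_end_data = {"tts_output": {"media_id": "test media id"}}

# After: the event carries a token that resolves to a live tts.ResultStream.
new_tts_end_data = {"tts_output": {"token": "<short-lived TTS token>"}}
# stream = tts.async_get_stream(hass, new_tts_end_data["tts_output"]["token"])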

homeassistant/components/wyoming/assist_satellite.py

@@ -178,7 +178,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
             self._pipeline_ended_event.set()
             self.device.set_is_active(False)
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
-            self.hass.add_job(self._client.write_event(Detect().event()))
+            self.config_entry.async_create_background_task(
+                self.hass,
+                self._client.write_event(Detect().event()),
+                f"{self.entity_id} {event.type}",
+            )
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
             # Wake word detection
             # Inform client of wake word detection
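Alongside the token change, every fire-and-forget hass.add_job(...) call in this event callback is replaced with config_entry.async_create_background_task(...), which names the task and ties its lifetime to the config entry. A minimal, hedged sketch of the pattern; entry, hass and send_detect() are stand-ins for illustration:

async def send_detect() -> None:
    """Hypothetical stand-in for self._client.write_event(Detect().event())."""

# Old: anonymous, untracked job
# hass.add_job(send_detect())

# New: named background task, tracked by (and cancelled with) the config entry
entry.async_create_background_task(
    hass,
    send_detect(),               # coroutine to run
    "<entity_id> <event type>",  # task name; the diff uses f"{self.entity_id} {event.type}"
)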
@@ -187,46 +191,59 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                     name=wake_word_output["wake_word_id"],
                     timestamp=wake_word_output.get("timestamp"),
                 )
-                self.hass.add_job(self._client.write_event(detection.event()))
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(detection.event()),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.STT_START:
             # Speech-to-text
             self.device.set_is_active(True)

             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Transcribe(language=event.data["metadata"]["language"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
             # User started speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStarted(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
             # User stopped speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStopped(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_END:
             # Speech-to-text transcript
             if event.data:
                 # Inform client of transcript
                 stt_text = event.data["stt_output"]["text"]
-                self.hass.add_job(
-                    self._client.write_event(Transcript(text=stt_text).event())
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(Transcript(text=stt_text).event()),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_START:
             # Text-to-speech text
             if event.data:
                 # Inform client of text
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Synthesize(
                             text=event.data["tts_input"],
@@ -235,22 +252,32 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                                 language=event.data.get("language"),
                             ),
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_END:
             # TTS stream
-            if event.data and (tts_output := event.data["tts_output"]):
-                media_id = tts_output["media_id"]
-                self.hass.add_job(self._stream_tts(media_id))
+            if (
+                event.data
+                and (tts_output := event.data["tts_output"])
+                and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
+            ):
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._stream_tts(stream),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.ERROR:
             # Pipeline error
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Error(
                             text=event.data["message"], code=event.data["code"]
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )

     async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
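The TTS_END branch above is the heart of the change: the walrus-chained guard only proceeds when tts.async_get_stream() actually returns a stream for the token (presumably returning None when the token does not resolve), so stale or unknown tokens simply result in no audio being streamed. The same guard, isolated as a hedged sketch with hass, event and _stream_tts as in the diff:

if (
    event.data
    and (tts_output := event.data["tts_output"])
    and (stream := tts.async_get_stream(hass, tts_output["token"]))
):
    # stream is a tts.ResultStream; hand it to the satellite streaming task
    await _stream_tts(stream)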
@@ -662,13 +689,16 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
         await self._client.disconnect()
         self._client = None

-    async def _stream_tts(self, media_id: str) -> None:
+    async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
         """Stream TTS WAV audio to satellite in chunks."""
         assert self._client is not None

-        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
-        if extension != "wav":
-            raise ValueError(f"Cannot stream audio format to satellite: {extension}")
+        if tts_result.extension != "wav":
+            raise ValueError(
+                f"Cannot stream audio format to satellite: {tts_result.extension}"
+            )
+
+        data = b"".join([chunk async for chunk in tts_result.async_stream_result()])

         with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
             sample_rate = wav_file.getframerate()
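_stream_tts now consumes the tts.ResultStream directly: the stream exposes the output extension plus an async_stream_result() iterator of byte chunks, which the method concatenates before reading the WAV header. A self-contained, hedged sketch of that consumption step; the helper name and return value are invented for illustration:

import io
import wave

async def collect_wav(tts_result) -> tuple[bytes, int, int, int]:
    """Buffer a TTS result stream and return (data, rate, width, channels).

    Sketch only: assumes tts_result behaves like the tts.ResultStream used in
    this diff, i.e. it has .extension and an async_stream_result() byte iterator.
    """
    if tts_result.extension != "wav":
        raise ValueError(f"Cannot stream audio format to satellite: {tts_result.extension}")

    data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        return (
            data,
            wav_file.getframerate(),
            wav_file.getsampwidth(),
            wav_file.getnchannels(),
        )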

tests/components/wyoming/test_satellite.py

@@ -35,6 +35,7 @@ from homeassistant.setup import async_setup_component
 from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient

 from tests.common import MockConfigEntry
+from tests.components.tts.common import MockResultStream


 async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
@@ -259,10 +260,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             async_pipeline_from_audio_stream,
         ),
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("wav", get_test_wav()),
-        ),
         patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
     ):
         entry = await setup_config_entry(hass)
@@ -411,10 +408,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         assert mock_client.synthesize.voice.name == "test voice"

         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
         pipeline_event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
         async with asyncio.timeout(1):
@@ -435,12 +433,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         )

         assert not device.is_active

-        # The client should have received another ping by now
-        async with asyncio.timeout(1):
-            await mock_client.ping_event.wait()
-
-        assert mock_client.ping is not None
-
         # Pipeline should automatically restart
         async with asyncio.timeout(1):
             await run_pipeline_called.wait()
@@ -746,10 +738,6 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             wraps=_async_pipeline_from_audio_stream,
         ) as mock_run_pipeline,
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("mp3", bytes(1)),
-        ),
         patch(
             "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
             _stream_tts,
@@ -779,10 +767,11 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             await mock_client.synthesize_event.wait()

         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "mp3", bytes(1))
         event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
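On the test side, patching tts.async_get_media_source_audio is no longer needed: MockResultStream registers a real in-memory result stream, and its .token is what the test places into the TTS_END event, mirroring production. Condensed sketch of the updated flow; hass, get_test_wav and the event callback come from the existing test module:

from tests.components.tts.common import MockResultStream

# Register an in-memory WAV result and reference it by token, as the updated
# test does; no async_get_media_source_audio patch is required any more.
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
pipeline_event_callback(
    assist_pipeline.PipelineEvent(
        assist_pipeline.PipelineEventType.TTS_END,
        {"tts_output": {"token": mock_tts_result_stream.token}},
    )
)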