Mirror of https://github.com/home-assistant/core.git (synced 2025-07-20 11:47:06 +00:00)

Wyoming to use tokens instead of media source IDs for TTS (#139668)

Co-authored-by: Franck Nijhof <git@frenck.dev>
Parent commit: 6499ad6cdb
This commit: 30ab068bfe
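In short: the assist pipeline's TTS_END event no longer hands the satellite a media source ID to fetch with tts.async_get_media_source_audio; it now carries a short-lived token that the handler resolves to an in-memory tts.ResultStream via tts.async_get_stream(hass, token). The same diff also replaces the fire-and-forget hass.add_job calls with config_entry.async_create_background_task, so the event writes are tracked per config entry. The sketch below is a framework-free illustration of the token hand-off only; StreamRegistry, register and resolve are hypothetical names for this example, not Home Assistant APIs.

# Hypothetical, self-contained sketch of a token -> audio-stream hand-off.
# In Home Assistant the equivalent lookup is tts.async_get_stream(hass, token),
# which returns a tts.ResultStream exposing .extension and .async_stream_result().

import asyncio
import secrets
from collections.abc import AsyncIterator


class StreamRegistry:
    """Map short-lived opaque tokens to buffered audio (illustration only)."""

    def __init__(self) -> None:
        self._audio: dict[str, bytes] = {}

    def register(self, audio: bytes) -> str:
        """Store audio and hand back a token the consumer can redeem later."""
        token = secrets.token_urlsafe(16)
        self._audio[token] = audio
        return token

    def resolve(self, token: str) -> AsyncIterator[bytes] | None:
        """Return an async chunk iterator for a known token, else None."""
        audio = self._audio.pop(token, None)
        if audio is None:
            return None

        async def _chunks() -> AsyncIterator[bytes]:
            for start in range(0, len(audio), 1024):
                yield audio[start : start + 1024]

        return _chunks()


async def _demo() -> None:
    registry = StreamRegistry()
    token = registry.register(b"\x00" * 4000)  # stand-in for synthesized WAV bytes
    stream = registry.resolve(token)
    assert stream is not None
    data = b"".join([chunk async for chunk in stream])
    print(f"token {token} resolved to {len(data)} bytes")


if __name__ == "__main__":
    asyncio.run(_demo())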
@@ -178,7 +178,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
             self._pipeline_ended_event.set()
             self.device.set_is_active(False)
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
-            self.hass.add_job(self._client.write_event(Detect().event()))
+            self.config_entry.async_create_background_task(
+                self.hass,
+                self._client.write_event(Detect().event()),
+                f"{self.entity_id} {event.type}",
+            )
         elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
             # Wake word detection
             # Inform client of wake word detection
@@ -187,46 +191,59 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                     name=wake_word_output["wake_word_id"],
                     timestamp=wake_word_output.get("timestamp"),
                 )
-                self.hass.add_job(self._client.write_event(detection.event()))
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(detection.event()),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.STT_START:
             # Speech-to-text
             self.device.set_is_active(True)

             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Transcribe(language=event.data["metadata"]["language"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
             # User started speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStarted(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
             # User stopped speaking
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         VoiceStopped(timestamp=event.data["timestamp"]).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.STT_END:
             # Speech-to-text transcript
             if event.data:
                 # Inform client of transript
                 stt_text = event.data["stt_output"]["text"]
-                self.hass.add_job(
-                    self._client.write_event(Transcript(text=stt_text).event())
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._client.write_event(Transcript(text=stt_text).event()),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_START:
             # Text-to-speech text
             if event.data:
                 # Inform client of text
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Synthesize(
                             text=event.data["tts_input"],
@@ -235,22 +252,32 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
                                 language=event.data.get("language"),
                             ),
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )
         elif event.type == assist_pipeline.PipelineEventType.TTS_END:
             # TTS stream
-            if event.data and (tts_output := event.data["tts_output"]):
-                media_id = tts_output["media_id"]
-                self.hass.add_job(self._stream_tts(media_id))
+            if (
+                event.data
+                and (tts_output := event.data["tts_output"])
+                and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
+            ):
+                self.config_entry.async_create_background_task(
+                    self.hass,
+                    self._stream_tts(stream),
+                    f"{self.entity_id} {event.type}",
+                )
         elif event.type == assist_pipeline.PipelineEventType.ERROR:
             # Pipeline error
             if event.data:
-                self.hass.add_job(
+                self.config_entry.async_create_background_task(
+                    self.hass,
                     self._client.write_event(
                         Error(
                             text=event.data["message"], code=event.data["code"]
                         ).event()
-                    )
+                    ),
+                    f"{self.entity_id} {event.type}",
                 )

     async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
@@ -662,13 +689,16 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
         await self._client.disconnect()
         self._client = None

-    async def _stream_tts(self, media_id: str) -> None:
+    async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
         """Stream TTS WAV audio to satellite in chunks."""
         assert self._client is not None

-        extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
-        if extension != "wav":
-            raise ValueError(f"Cannot stream audio format to satellite: {extension}")
+        if tts_result.extension != "wav":
+            raise ValueError(
+                f"Cannot stream audio format to satellite: {tts_result.extension}"
+            )
+
+        data = b"".join([chunk async for chunk in tts_result.async_stream_result()])

         with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
             sample_rate = wav_file.getframerate()
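The reworked _stream_tts buffers the whole tts.ResultStream into memory and then re-opens it with the wave module before chunking audio out to the satellite. The snippet below isolates that buffer-and-chunk step with the standard library only; it runs as-is, and the helper names and 1024-frame chunk size are illustrative choices for this sketch, not values taken from the integration.

# Stand-alone illustration of the buffer-and-chunk step used by _stream_tts above.
# Only io and wave from the standard library are needed; the chunk size is arbitrary.

import io
import wave


def make_test_wav(seconds: float = 0.1, sample_rate: int = 16000) -> bytes:
    """Build a small silent 16-bit mono WAV entirely in memory."""
    with io.BytesIO() as wav_io:
        with wave.open(wav_io, "wb") as wav_file:
            wav_file.setnchannels(1)
            wav_file.setsampwidth(2)
            wav_file.setframerate(sample_rate)
            wav_file.writeframes(b"\x00\x00" * int(sample_rate * seconds))
        return wav_io.getvalue()


def iter_wav_chunks(data: bytes, samples_per_chunk: int = 1024):
    """Yield (rate, width, channels, chunk_bytes) tuples from buffered WAV data."""
    with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
        sample_rate = wav_file.getframerate()
        sample_width = wav_file.getsampwidth()
        sample_channels = wav_file.getnchannels()
        while chunk := wav_file.readframes(samples_per_chunk):
            yield sample_rate, sample_width, sample_channels, chunk


if __name__ == "__main__":
    wav_bytes = make_test_wav()
    total = sum(len(chunk) for *_, chunk in iter_wav_chunks(wav_bytes))
    print(f"chunked {total} bytes of audio")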
@@ -35,6 +35,7 @@ from homeassistant.setup import async_setup_component
 from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient

 from tests.common import MockConfigEntry
+from tests.components.tts.common import MockResultStream


 async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
@@ -259,10 +260,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             async_pipeline_from_audio_stream,
         ),
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("wav", get_test_wav()),
-        ),
         patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
     ):
         entry = await setup_config_entry(hass)
@@ -411,10 +408,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         assert mock_client.synthesize.voice.name == "test voice"

         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
         pipeline_event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )
         async with asyncio.timeout(1):
@@ -435,12 +433,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
         )
         assert not device.is_active

-        # The client should have received another ping by now
-        async with asyncio.timeout(1):
-            await mock_client.ping_event.wait()
-
-        assert mock_client.ping is not None
-
         # Pipeline should automatically restart
         async with asyncio.timeout(1):
             await run_pipeline_called.wait()
@@ -746,10 +738,6 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
             wraps=_async_pipeline_from_audio_stream,
         ) as mock_run_pipeline,
-        patch(
-            "homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
-            return_value=("mp3", bytes(1)),
-        ),
         patch(
             "homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
             _stream_tts,
@@ -779,10 +767,11 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
             await mock_client.synthesize_event.wait()

         # Text-to-speech media
+        mock_tts_result_stream = MockResultStream(hass, "mp3", bytes(1))
         event_callback(
             assist_pipeline.PipelineEvent(
                 assist_pipeline.PipelineEventType.TTS_END,
-                {"tts_output": {"media_id": "test media id"}},
+                {"tts_output": {"token": mock_tts_result_stream.token}},
             )
         )

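On the test side, the patch of tts.async_get_media_source_audio is gone; tests instead register a MockResultStream (imported from tests.components.tts.common) and pass its token through the TTS_END event. The excerpt below condenses that pattern from the hunks above; it assumes the surrounding fixtures (hass, get_test_wav, pipeline_event_callback) and is not a standalone test.

# Condensed from the test hunks above; requires the Home Assistant test harness to run.
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
pipeline_event_callback(
    assist_pipeline.PipelineEvent(
        assist_pipeline.PipelineEventType.TTS_END,
        {"tts_output": {"token": mock_tts_result_stream.token}},
    )
)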