mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 04:07:08 +00:00
Wyoming to use tokens instead of media source IDs for TTS (#139668)
Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
6499ad6cdb
commit
30ab068bfe
@ -178,7 +178,11 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
self._pipeline_ended_event.set()
|
self._pipeline_ended_event.set()
|
||||||
self.device.set_is_active(False)
|
self.device.set_is_active(False)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_START:
|
||||||
self.hass.add_job(self._client.write_event(Detect().event()))
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._client.write_event(Detect().event()),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
elif event.type == assist_pipeline.PipelineEventType.WAKE_WORD_END:
|
||||||
# Wake word detection
|
# Wake word detection
|
||||||
# Inform client of wake word detection
|
# Inform client of wake word detection
|
||||||
@ -187,46 +191,59 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
name=wake_word_output["wake_word_id"],
|
name=wake_word_output["wake_word_id"],
|
||||||
timestamp=wake_word_output.get("timestamp"),
|
timestamp=wake_word_output.get("timestamp"),
|
||||||
)
|
)
|
||||||
self.hass.add_job(self._client.write_event(detection.event()))
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._client.write_event(detection.event()),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
elif event.type == assist_pipeline.PipelineEventType.STT_START:
|
||||||
# Speech-to-text
|
# Speech-to-text
|
||||||
self.device.set_is_active(True)
|
self.device.set_is_active(True)
|
||||||
|
|
||||||
if event.data:
|
if event.data:
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
self._client.write_event(
|
self._client.write_event(
|
||||||
Transcribe(language=event.data["metadata"]["language"]).event()
|
Transcribe(language=event.data["metadata"]["language"]).event()
|
||||||
)
|
),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_START:
|
||||||
# User started speaking
|
# User started speaking
|
||||||
if event.data:
|
if event.data:
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
self._client.write_event(
|
self._client.write_event(
|
||||||
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
VoiceStarted(timestamp=event.data["timestamp"]).event()
|
||||||
)
|
),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
elif event.type == assist_pipeline.PipelineEventType.STT_VAD_END:
|
||||||
# User stopped speaking
|
# User stopped speaking
|
||||||
if event.data:
|
if event.data:
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
self._client.write_event(
|
self._client.write_event(
|
||||||
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
VoiceStopped(timestamp=event.data["timestamp"]).event()
|
||||||
)
|
),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
elif event.type == assist_pipeline.PipelineEventType.STT_END:
|
||||||
# Speech-to-text transcript
|
# Speech-to-text transcript
|
||||||
if event.data:
|
if event.data:
|
||||||
# Inform client of transript
|
# Inform client of transript
|
||||||
stt_text = event.data["stt_output"]["text"]
|
stt_text = event.data["stt_output"]["text"]
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
self._client.write_event(Transcript(text=stt_text).event())
|
self.hass,
|
||||||
|
self._client.write_event(Transcript(text=stt_text).event()),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
elif event.type == assist_pipeline.PipelineEventType.TTS_START:
|
||||||
# Text-to-speech text
|
# Text-to-speech text
|
||||||
if event.data:
|
if event.data:
|
||||||
# Inform client of text
|
# Inform client of text
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
self._client.write_event(
|
self._client.write_event(
|
||||||
Synthesize(
|
Synthesize(
|
||||||
text=event.data["tts_input"],
|
text=event.data["tts_input"],
|
||||||
@ -235,22 +252,32 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
language=event.data.get("language"),
|
language=event.data.get("language"),
|
||||||
),
|
),
|
||||||
).event()
|
).event()
|
||||||
)
|
),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
elif event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||||
# TTS stream
|
# TTS stream
|
||||||
if event.data and (tts_output := event.data["tts_output"]):
|
if (
|
||||||
media_id = tts_output["media_id"]
|
event.data
|
||||||
self.hass.add_job(self._stream_tts(media_id))
|
and (tts_output := event.data["tts_output"])
|
||||||
|
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
|
||||||
|
):
|
||||||
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._stream_tts(stream),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
|
)
|
||||||
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
elif event.type == assist_pipeline.PipelineEventType.ERROR:
|
||||||
# Pipeline error
|
# Pipeline error
|
||||||
if event.data:
|
if event.data:
|
||||||
self.hass.add_job(
|
self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
self._client.write_event(
|
self._client.write_event(
|
||||||
Error(
|
Error(
|
||||||
text=event.data["message"], code=event.data["code"]
|
text=event.data["message"], code=event.data["code"]
|
||||||
).event()
|
).event()
|
||||||
)
|
),
|
||||||
|
f"{self.entity_id} {event.type}",
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
async def async_announce(self, announcement: AssistSatelliteAnnouncement) -> None:
|
||||||
@ -662,13 +689,16 @@ class WyomingAssistSatellite(WyomingSatelliteEntity, AssistSatelliteEntity):
|
|||||||
await self._client.disconnect()
|
await self._client.disconnect()
|
||||||
self._client = None
|
self._client = None
|
||||||
|
|
||||||
async def _stream_tts(self, media_id: str) -> None:
|
async def _stream_tts(self, tts_result: tts.ResultStream) -> None:
|
||||||
"""Stream TTS WAV audio to satellite in chunks."""
|
"""Stream TTS WAV audio to satellite in chunks."""
|
||||||
assert self._client is not None
|
assert self._client is not None
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(self.hass, media_id)
|
if tts_result.extension != "wav":
|
||||||
if extension != "wav":
|
raise ValueError(
|
||||||
raise ValueError(f"Cannot stream audio format to satellite: {extension}")
|
f"Cannot stream audio format to satellite: {tts_result.extension}"
|
||||||
|
)
|
||||||
|
|
||||||
|
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
sample_rate = wav_file.getframerate()
|
sample_rate = wav_file.getframerate()
|
||||||
|
@ -35,6 +35,7 @@ from homeassistant.setup import async_setup_component
|
|||||||
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
from . import SATELLITE_INFO, WAKE_WORD_INFO, MockAsyncTcpClient
|
||||||
|
|
||||||
from tests.common import MockConfigEntry
|
from tests.common import MockConfigEntry
|
||||||
|
from tests.components.tts.common import MockResultStream
|
||||||
|
|
||||||
|
|
||||||
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
async def setup_config_entry(hass: HomeAssistant) -> MockConfigEntry:
|
||||||
@ -259,10 +260,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
return_value=("wav", get_test_wav()),
|
|
||||||
),
|
|
||||||
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
patch("homeassistant.components.wyoming.assist_satellite._PING_SEND_DELAY", 0),
|
||||||
):
|
):
|
||||||
entry = await setup_config_entry(hass)
|
entry = await setup_config_entry(hass)
|
||||||
@ -411,10 +408,11 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||||||
assert mock_client.synthesize.voice.name == "test voice"
|
assert mock_client.synthesize.voice.name == "test voice"
|
||||||
|
|
||||||
# Text-to-speech media
|
# Text-to-speech media
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", get_test_wav())
|
||||||
pipeline_event_callback(
|
pipeline_event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
assist_pipeline.PipelineEventType.TTS_END,
|
assist_pipeline.PipelineEventType.TTS_END,
|
||||||
{"tts_output": {"media_id": "test media id"}},
|
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
@ -435,12 +433,6 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
|
|||||||
)
|
)
|
||||||
assert not device.is_active
|
assert not device.is_active
|
||||||
|
|
||||||
# The client should have received another ping by now
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await mock_client.ping_event.wait()
|
|
||||||
|
|
||||||
assert mock_client.ping is not None
|
|
||||||
|
|
||||||
# Pipeline should automatically restart
|
# Pipeline should automatically restart
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await run_pipeline_called.wait()
|
await run_pipeline_called.wait()
|
||||||
@ -746,10 +738,6 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
wraps=_async_pipeline_from_audio_stream,
|
wraps=_async_pipeline_from_audio_stream,
|
||||||
) as mock_run_pipeline,
|
) as mock_run_pipeline,
|
||||||
patch(
|
|
||||||
"homeassistant.components.wyoming.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
return_value=("mp3", bytes(1)),
|
|
||||||
),
|
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
|
"homeassistant.components.wyoming.assist_satellite.WyomingAssistSatellite._stream_tts",
|
||||||
_stream_tts,
|
_stream_tts,
|
||||||
@ -779,10 +767,11 @@ async def test_tts_not_wav(hass: HomeAssistant) -> None:
|
|||||||
await mock_client.synthesize_event.wait()
|
await mock_client.synthesize_event.wait()
|
||||||
|
|
||||||
# Text-to-speech media
|
# Text-to-speech media
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "mp3", bytes(1))
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
assist_pipeline.PipelineEventType.TTS_END,
|
assist_pipeline.PipelineEventType.TTS_END,
|
||||||
{"tts_output": {"media_id": "test media id"}},
|
{"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user