Migrate VoIP to use Assist Pipeline TTS tokens (#139671)

* Migrate VoIP to use pipeline token

* Migrate announcements to use TTS token
This commit is contained in:
Paulus Schoutsen 2025-04-22 10:24:24 -04:00 committed by GitHub
parent 871a7c87bf
commit 8aa30b0ccb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 64 additions and 92 deletions

View File

@ -408,10 +408,18 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
"""Play an announcement once.""" """Play an announcement once."""
_LOGGER.debug("Playing announcement") _LOGGER.debug("Playing announcement")
try: if announcement.tts_token is None:
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY) _LOGGER.error("Only TTS announcements are supported")
await self._send_tts(announcement.original_media_id, wait_for_tone=False) return
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
stream = tts.async_get_stream(self.hass, announcement.tts_token)
if stream is None:
_LOGGER.error("TTS stream no longer available")
return
try:
await self._send_tts(stream, wait_for_tone=False)
if not self._run_pipeline_after_announce: if not self._run_pipeline_after_announce:
# Delay before looping announcement # Delay before looping announcement
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY) await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
@ -442,11 +450,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
) )
elif event.type == PipelineEventType.TTS_END: elif event.type == PipelineEventType.TTS_END:
# Send TTS audio to caller over RTP # Send TTS audio to caller over RTP
if event.data and (tts_output := event.data["tts_output"]): if (
media_id = tts_output["media_id"] event.data
and (tts_output := event.data["tts_output"])
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
):
self.config_entry.async_create_background_task( self.config_entry.async_create_background_task(
self.hass, self.hass,
self._send_tts(media_id), self._send_tts(tts_stream=stream),
"voip_pipeline_tts", "voip_pipeline_tts",
) )
else: else:
@ -457,19 +468,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self._pipeline_had_error = True self._pipeline_had_error = True
_LOGGER.warning(event) _LOGGER.warning(event)
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None: async def _send_tts(
self,
tts_stream: tts.ResultStream,
wait_for_tone: bool = True,
) -> None:
"""Send TTS audio to caller via RTP.""" """Send TTS audio to caller via RTP."""
try: try:
if self.transport is None: if self.transport is None:
return # not connected return # not connected
extension, data = await tts.async_get_media_source_audio( data = b"".join([chunk async for chunk in tts_stream.async_stream_result()])
self.hass,
media_id,
)
if extension != "wav": if tts_stream.extension != "wav":
raise ValueError(f"Only WAV audio can be streamed, got {extension}") raise ValueError(
f"Only TTS WAV audio can be streamed, got {tts_stream.extension}"
)
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING): if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
# Don't overlap TTS and processing beep # Don't overlap TTS and processing beep

View File

@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
"""Mock the TTS cache dir with empty dir.""" """Mock the TTS cache dir with empty dir."""
def _empty_wav() -> bytes: def _empty_wav(framerate=16000) -> bytes:
"""Return bytes of an empty WAV file.""" """Return bytes of an empty WAV file."""
with io.BytesIO() as wav_io: with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb") wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file: with wav_file:
wav_file.setframerate(16000) wav_file.setframerate(framerate)
wav_file.setsampwidth(2) wav_file.setsampwidth(2)
wav_file.setnchannels(1) wav_file.setnchannels(1)
@ -307,10 +307,11 @@ async def test_pipeline(
assert satellite.state == AssistSatelliteState.RESPONDING assert satellite.state == AssistSatelliteState.RESPONDING
# Proceed with media output # Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback( event_callback(
assist_pipeline.PipelineEvent( assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END, type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}}, data={"tts_output": {"token": mock_tts_result_stream.token}},
) )
) )
@ -326,22 +327,11 @@ async def test_pipeline(
original_tts_response_finished() original_tts_response_finished()
done.set() done.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
assert media_source_id == _MEDIA_ID
return ("wav", _empty_wav())
with ( with (
patch( patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
), ),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "tts_response_finished", tts_response_finished), patch.object(satellite, "tts_response_finished", tts_response_finished),
): ):
satellite._tones = Tones(0) satellite._tones = Tones(0)
@ -457,10 +447,11 @@ async def test_tts_timeout(
) )
# Proceed with media output # Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback( event_callback(
assist_pipeline.PipelineEvent( assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END, type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}}, data={"tts_output": {"token": mock_tts_result_stream.token}},
) )
) )
@ -474,22 +465,9 @@ async def test_tts_timeout(
# Block here to force a timeout in _send_tts # Block here to force a timeout in _send_tts
await asyncio.sleep(2) await asyncio.sleep(2)
async def async_get_media_source_audio( with patch(
hass: HomeAssistant, "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
media_source_id: str, new=async_pipeline_from_audio_stream,
) -> tuple[str, bytes]:
# Should time out immediately
return ("wav", _empty_wav())
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
): ):
satellite._tts_extra_timeout = 0.001 satellite._tts_extra_timeout = 0.001
for tone in Tones: for tone in Tones:
@ -568,29 +546,18 @@ async def test_tts_wrong_extension(
) )
# Proceed with media output # Proceed with media output
# Should fail because it's not "wav"
mock_tts_result_stream = MockResultStream(hass, "mp3", b"")
event_callback( event_callback(
assist_pipeline.PipelineEvent( assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END, type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}}, data={"tts_output": {"token": mock_tts_result_stream.token}},
) )
) )
async def async_get_media_source_audio( with patch(
hass: HomeAssistant, "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
media_source_id: str, new=async_pipeline_from_audio_stream,
) -> tuple[str, bytes]:
# Should fail because it's not "wav"
return ("mp3", b"")
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
): ):
satellite.transport = Mock() satellite.transport = Mock()
@ -663,36 +630,18 @@ async def test_tts_wrong_wav_format(
) )
# Proceed with media output # Proceed with media output
# Should fail because it's not 16Khz
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050))
event_callback( event_callback(
assist_pipeline.PipelineEvent( assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END, type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}}, data={"tts_output": {"token": mock_tts_result_stream.token}},
) )
) )
async def async_get_media_source_audio( with patch(
hass: HomeAssistant, "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
media_source_id: str, new=async_pipeline_from_audio_stream,
) -> tuple[str, bytes]:
# Should fail because it's not 16Khz, 16-bit mono
with io.BytesIO() as wav_io:
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
with wav_file:
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(2)
return ("wav", wav_io.getvalue())
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
): ):
satellite.transport = Mock() satellite.transport = Mock()
@ -878,10 +827,11 @@ async def test_announce(
assert err.value.translation_domain == "voip" assert err.value.translation_domain == "voip"
assert err.value.translation_key == "non_tts_announcement" assert err.value.translation_key == "non_tts_announcement"
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement( announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement", message="test announcement",
media_id=_MEDIA_ID, media_id=_MEDIA_ID,
tts_token="test-token", tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID, original_media_id=_MEDIA_ID,
media_id_source="tts", media_id_source="tts",
) )
@ -907,7 +857,9 @@ async def test_announce(
async with asyncio.timeout(1): async with asyncio.timeout(1):
await announce_task await announce_task
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False) mock_send_tts.assert_called_once_with(
mock_tts_result_stream, wait_for_tone=False
)
@pytest.mark.usefixtures("socket_enabled") @pytest.mark.usefixtures("socket_enabled")
@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE & assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
) )
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement( announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement", message="test announcement",
media_id=_MEDIA_ID, media_id=_MEDIA_ID,
tts_token="test-token", tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID, original_media_id=_MEDIA_ID,
media_id_source="tts", media_id_source="tts",
) )
@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
async with asyncio.timeout(1): async with asyncio.timeout(1):
await announce_task await announce_task
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False) mock_send_tts.assert_called_once_with(
mock_tts_result_stream, wait_for_tone=False
)
@pytest.mark.usefixtures("socket_enabled") @pytest.mark.usefixtures("socket_enabled")
@ -979,10 +934,11 @@ async def test_announce_timeout(
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE & assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
) )
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement( announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement", message="test announcement",
media_id=_MEDIA_ID, media_id=_MEDIA_ID,
tts_token="test-token", tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID, original_media_id=_MEDIA_ID,
media_id_source="tts", media_id_source="tts",
) )
@ -1020,10 +976,11 @@ async def test_start_conversation(
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION & assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
) )
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
announcement = assist_satellite.AssistSatelliteAnnouncement( announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement", message="test announcement",
media_id=_MEDIA_ID, media_id=_MEDIA_ID,
tts_token="test-token", tts_token=mock_tts_result_stream.token,
original_media_id=_MEDIA_ID, original_media_id=_MEDIA_ID,
media_id_source="tts", media_id_source="tts",
) )
@ -1061,10 +1018,11 @@ async def test_start_conversation(
) )
# Proceed with media output # Proceed with media output
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
event_callback( event_callback(
assist_pipeline.PipelineEvent( assist_pipeline.PipelineEvent(
type=assist_pipeline.PipelineEventType.TTS_END, type=assist_pipeline.PipelineEventType.TTS_END,
data={"tts_output": {"media_id": _MEDIA_ID}}, data={"tts_output": {"token": mock_tts_result_stream.token}},
) )
) )