Migrate ESPHome to use token instead of media source ID for legacy Assist Pipelines (#139665)

Migrate legacy ESPHome devices to use TTS token

Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
Paulus Schoutsen 2025-03-30 10:53:49 -04:00 committed by GitHub
parent ec20e41836
commit 86be626c69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 75 additions and 106 deletions

View File

@ -310,12 +310,13 @@ class EsphomeAssistSatellite(
self.entry_data.api_version
)
)
if feature_flags & VoiceAssistantFeature.SPEAKER:
media_id = tts_output["media_id"]
if feature_flags & VoiceAssistantFeature.SPEAKER and (
stream := tts.async_get_stream(self.hass, tts_output["token"])
):
self._tts_streaming_task = (
self.config_entry.async_create_background_task(
self.hass,
self._stream_tts_audio(media_id),
self._stream_tts_audio(stream),
"esphome_voice_assistant_tts",
)
)
@ -564,7 +565,7 @@ class EsphomeAssistSatellite(
async def _stream_tts_audio(
self,
media_id: str,
tts_result: tts.ResultStream,
sample_rate: int = 16000,
sample_width: int = 2,
sample_channels: int = 1,
@ -579,15 +580,14 @@ class EsphomeAssistSatellite(
if not self._is_running:
return
extension, data = await tts.async_get_media_source_audio(
self.hass,
media_id,
)
if extension != "wav":
_LOGGER.error("Only WAV audio can be streamed, got %s", extension)
if tts_result.extension != "wav":
_LOGGER.error(
"Only WAV audio can be streamed, got %s", tts_result.extension
)
return
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
if (
(wav_file.getframerate() != sample_rate)

View File

@ -58,6 +58,7 @@ from homeassistant.helpers import (
intent as intent_helper,
)
from homeassistant.helpers.entity_component import EntityComponent
from homeassistant.helpers.network import get_url
from .conftest import MockESPHomeDevice
@ -133,8 +134,6 @@ async def test_pipeline_api_audio(
) -> None:
"""Test a complete pipeline run with API audio (over the TCP connection)."""
conversation_id = "test-conversation-id"
media_url = "http://test.url"
media_id = "test-media-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
@ -328,15 +327,22 @@ async def test_pipeline_api_audio(
assert satellite.state == AssistSatelliteState.RESPONDING
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={"tts_output": {"url": media_url, "media_id": media_id}},
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
{"url": media_url},
{"url": get_url(hass) + mock_tts_result_stream.url},
)
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
@ -355,12 +361,6 @@ async def test_pipeline_api_audio(
original_handle_pipeline_finished()
pipeline_finished.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("wav", mock_wav)
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
@ -373,10 +373,6 @@ async def test_pipeline_api_audio(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
patch.object(satellite, "tts_response_finished", tts_response_finished),
@ -434,8 +430,6 @@ async def test_pipeline_udp_audio(
mainly focused on the UDP server.
"""
conversation_id = "test-conversation-id"
media_url = "http://test.url"
media_id = "test-media-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
@ -522,10 +516,17 @@ async def test_pipeline_udp_audio(
)
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={"tts_output": {"url": media_url, "media_id": media_id}},
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
@ -538,12 +539,6 @@ async def test_pipeline_udp_audio(
original_handle_pipeline_finished()
pipeline_finished.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("wav", mock_wav)
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
@ -567,10 +562,6 @@ async def test_pipeline_udp_audio(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
@ -652,8 +643,6 @@ async def test_pipeline_media_player(
mainly focused on tts_response_finished getting automatically called.
"""
conversation_id = "test-conversation-id"
media_url = "http://test.url"
media_id = "test-media-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
@ -733,10 +722,17 @@ async def test_pipeline_media_player(
)
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={"tts_output": {"url": media_url, "media_id": media_id}},
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
@ -749,12 +745,6 @@ async def test_pipeline_media_player(
original_handle_pipeline_finished()
pipeline_finished.set()
async def async_get_media_source_audio(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("wav", mock_wav)
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
@ -767,10 +757,6 @@ async def test_pipeline_media_player(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.async_get_media_source_audio",
new=async_get_media_source_audio,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
@ -944,80 +930,63 @@ async def test_streaming_tts_errors(
# Should not stream if not running
satellite._is_running = False
await satellite._stream_tts_audio("test-media-id")
await satellite._stream_tts_audio(MockResultStream(hass, "wav", mock_wav))
mock_client.send_voice_assistant_audio.assert_not_called()
satellite._is_running = True
# Should only stream WAV
async def get_mp3(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
return ("mp3", b"")
with patch(
"homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
):
await satellite._stream_tts_audio("test-media-id")
mock_client.send_voice_assistant_audio.assert_not_called()
await satellite._stream_tts_audio(MockResultStream(hass, "mp3", b""))
mock_client.send_voice_assistant_audio.assert_not_called()
# Needs to be the correct sample rate, etc.
async def get_bad_wav(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(48000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(b"test-wav")
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(48000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(b"test-wav")
return ("wav", wav_io.getvalue())
mock_tts_result_stream = MockResultStream(hass, "wav", wav_io.getvalue())
with patch(
"homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
):
await satellite._stream_tts_audio("test-media-id")
mock_client.send_voice_assistant_audio.assert_not_called()
await satellite._stream_tts_audio(mock_tts_result_stream)
mock_client.send_voice_assistant_audio.assert_not_called()
# Check that TTS_STREAM_* events still get sent after cancel
media_fetched = asyncio.Event()
async def get_slow_wav(
hass: HomeAssistant,
media_source_id: str,
) -> tuple[str, bytes]:
mock_tts_result_stream = MockResultStream(hass, "wav", b"")
async def async_stream_result_slowly():
media_fetched.set()
await asyncio.sleep(1)
return ("wav", mock_wav)
yield mock_wav
mock_tts_result_stream.async_stream_result = async_stream_result_slowly
mock_client.send_voice_assistant_event.reset_mock()
with patch(
"homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
):
task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
async with asyncio.timeout(1):
# Wait for media to be fetched
await media_fetched.wait()
# Cancel task
task.cancel()
await task
task = asyncio.create_task(satellite._stream_tts_audio(mock_tts_result_stream))
async with asyncio.timeout(1):
# Wait for media to be fetched
await media_fetched.wait()
# No audio should have gone out
mock_client.send_voice_assistant_audio.assert_not_called()
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
# Cancel task
task.cancel()
await task
# The TTS_STREAM_* events should have gone out
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
{},
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)
# No audio should have gone out
mock_client.send_voice_assistant_audio.assert_not_called()
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
# The TTS_STREAM_* events should have gone out
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
{},
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)
async def test_tts_format_from_media_player(