mirror of
https://github.com/home-assistant/core.git
synced 2025-04-24 01:08:12 +00:00
Migrate ESPHome to use token instead of media source ID for legacy Assist Pipelines (#139665)
Migrate legacy ESPHome devices to use TTS token Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
ec20e41836
commit
86be626c69
@ -310,12 +310,13 @@ class EsphomeAssistSatellite(
|
||||
self.entry_data.api_version
|
||||
)
|
||||
)
|
||||
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
||||
media_id = tts_output["media_id"]
|
||||
if feature_flags & VoiceAssistantFeature.SPEAKER and (
|
||||
stream := tts.async_get_stream(self.hass, tts_output["token"])
|
||||
):
|
||||
self._tts_streaming_task = (
|
||||
self.config_entry.async_create_background_task(
|
||||
self.hass,
|
||||
self._stream_tts_audio(media_id),
|
||||
self._stream_tts_audio(stream),
|
||||
"esphome_voice_assistant_tts",
|
||||
)
|
||||
)
|
||||
@ -564,7 +565,7 @@ class EsphomeAssistSatellite(
|
||||
|
||||
async def _stream_tts_audio(
|
||||
self,
|
||||
media_id: str,
|
||||
tts_result: tts.ResultStream,
|
||||
sample_rate: int = 16000,
|
||||
sample_width: int = 2,
|
||||
sample_channels: int = 1,
|
||||
@ -579,15 +580,14 @@ class EsphomeAssistSatellite(
|
||||
if not self._is_running:
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
self.hass,
|
||||
media_id,
|
||||
)
|
||||
|
||||
if extension != "wav":
|
||||
_LOGGER.error("Only WAV audio can be streamed, got %s", extension)
|
||||
if tts_result.extension != "wav":
|
||||
_LOGGER.error(
|
||||
"Only WAV audio can be streamed, got %s", tts_result.extension
|
||||
)
|
||||
return
|
||||
|
||||
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||
|
||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||
if (
|
||||
(wav_file.getframerate() != sample_rate)
|
||||
|
@ -58,6 +58,7 @@ from homeassistant.helpers import (
|
||||
intent as intent_helper,
|
||||
)
|
||||
from homeassistant.helpers.entity_component import EntityComponent
|
||||
from homeassistant.helpers.network import get_url
|
||||
|
||||
from .conftest import MockESPHomeDevice
|
||||
|
||||
@ -133,8 +134,6 @@ async def test_pipeline_api_audio(
|
||||
) -> None:
|
||||
"""Test a complete pipeline run with API audio (over the TCP connection)."""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
@ -328,15 +327,22 @@ async def test_pipeline_api_audio(
|
||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||
|
||||
# Should return mock_wav audio
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": "test-media-id",
|
||||
"url": mock_tts_result_stream.url,
|
||||
"token": mock_tts_result_stream.token,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
||||
{"url": media_url},
|
||||
{"url": get_url(hass) + mock_tts_result_stream.url},
|
||||
)
|
||||
|
||||
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||
@ -355,12 +361,6 @@ async def test_pipeline_api_audio(
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
@ -373,10 +373,6 @@ async def test_pipeline_api_audio(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
@ -434,8 +430,6 @@ async def test_pipeline_udp_audio(
|
||||
mainly focused on the UDP server.
|
||||
"""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
@ -522,10 +516,17 @@ async def test_pipeline_udp_audio(
|
||||
)
|
||||
|
||||
# Should return mock_wav audio
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": "test-media-id",
|
||||
"url": mock_tts_result_stream.url,
|
||||
"token": mock_tts_result_stream.token,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@ -538,12 +539,6 @@ async def test_pipeline_udp_audio(
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
@ -567,10 +562,6 @@ async def test_pipeline_udp_audio(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
@ -652,8 +643,6 @@ async def test_pipeline_media_player(
|
||||
mainly focused on tts_response_finished getting automatically called.
|
||||
"""
|
||||
conversation_id = "test-conversation-id"
|
||||
media_url = "http://test.url"
|
||||
media_id = "test-media-id"
|
||||
|
||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||
mock_client=mock_client,
|
||||
@ -733,10 +722,17 @@ async def test_pipeline_media_player(
|
||||
)
|
||||
|
||||
# Should return mock_wav audio
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||
event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": "test-media-id",
|
||||
"url": mock_tts_result_stream.url,
|
||||
"token": mock_tts_result_stream.token,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@ -749,12 +745,6 @@ async def test_pipeline_media_player(
|
||||
original_handle_pipeline_finished()
|
||||
pipeline_finished.set()
|
||||
|
||||
async def async_get_media_source_audio(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("wav", mock_wav)
|
||||
|
||||
tts_finished = asyncio.Event()
|
||||
original_tts_response_finished = satellite.tts_response_finished
|
||||
|
||||
@ -767,10 +757,6 @@ async def test_pipeline_media_player(
|
||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio",
|
||||
new=async_get_media_source_audio,
|
||||
),
|
||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||
):
|
||||
@ -944,80 +930,63 @@ async def test_streaming_tts_errors(
|
||||
|
||||
# Should not stream if not running
|
||||
satellite._is_running = False
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
await satellite._stream_tts_audio(MockResultStream(hass, "wav", mock_wav))
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
satellite._is_running = True
|
||||
|
||||
# Should only stream WAV
|
||||
async def get_mp3(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
return ("mp3", b"")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
|
||||
):
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
await satellite._stream_tts_audio(MockResultStream(hass, "mp3", b""))
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
|
||||
# Needs to be the correct sample rate, etc.
|
||||
async def get_bad_wav(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(48000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(b"test-wav")
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(48000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(b"test-wav")
|
||||
|
||||
return ("wav", wav_io.getvalue())
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", wav_io.getvalue())
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
|
||||
):
|
||||
await satellite._stream_tts_audio("test-media-id")
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
await satellite._stream_tts_audio(mock_tts_result_stream)
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
|
||||
# Check that TTS_STREAM_* events still get sent after cancel
|
||||
media_fetched = asyncio.Event()
|
||||
|
||||
async def get_slow_wav(
|
||||
hass: HomeAssistant,
|
||||
media_source_id: str,
|
||||
) -> tuple[str, bytes]:
|
||||
mock_tts_result_stream = MockResultStream(hass, "wav", b"")
|
||||
|
||||
async def async_stream_result_slowly():
|
||||
media_fetched.set()
|
||||
await asyncio.sleep(1)
|
||||
return ("wav", mock_wav)
|
||||
yield mock_wav
|
||||
|
||||
mock_tts_result_stream.async_stream_result = async_stream_result_slowly
|
||||
|
||||
mock_client.send_voice_assistant_event.reset_mock()
|
||||
with patch(
|
||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
|
||||
):
|
||||
task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
|
||||
async with asyncio.timeout(1):
|
||||
# Wait for media to be fetched
|
||||
await media_fetched.wait()
|
||||
|
||||
# Cancel task
|
||||
task.cancel()
|
||||
await task
|
||||
task = asyncio.create_task(satellite._stream_tts_audio(mock_tts_result_stream))
|
||||
async with asyncio.timeout(1):
|
||||
# Wait for media to be fetched
|
||||
await media_fetched.wait()
|
||||
|
||||
# No audio should have gone out
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
||||
# Cancel task
|
||||
task.cancel()
|
||||
await task
|
||||
|
||||
# The TTS_STREAM_* events should have gone out
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||
{},
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||
{},
|
||||
)
|
||||
# No audio should have gone out
|
||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
||||
|
||||
# The TTS_STREAM_* events should have gone out
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||
{},
|
||||
)
|
||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||
{},
|
||||
)
|
||||
|
||||
|
||||
async def test_tts_format_from_media_player(
|
||||
|
Loading…
x
Reference in New Issue
Block a user