mirror of
https://github.com/home-assistant/core.git
synced 2025-07-18 18:57:06 +00:00
Migrate ESPHome to use token instead of media source ID for legacy Assist Pipelines (#139665)
Migrate legacy ESPHome devices to use TTS token Co-authored-by: Franck Nijhof <git@frenck.dev>
This commit is contained in:
parent
ec20e41836
commit
86be626c69
@ -310,12 +310,13 @@ class EsphomeAssistSatellite(
|
|||||||
self.entry_data.api_version
|
self.entry_data.api_version
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if feature_flags & VoiceAssistantFeature.SPEAKER:
|
if feature_flags & VoiceAssistantFeature.SPEAKER and (
|
||||||
media_id = tts_output["media_id"]
|
stream := tts.async_get_stream(self.hass, tts_output["token"])
|
||||||
|
):
|
||||||
self._tts_streaming_task = (
|
self._tts_streaming_task = (
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._stream_tts_audio(media_id),
|
self._stream_tts_audio(stream),
|
||||||
"esphome_voice_assistant_tts",
|
"esphome_voice_assistant_tts",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -564,7 +565,7 @@ class EsphomeAssistSatellite(
|
|||||||
|
|
||||||
async def _stream_tts_audio(
|
async def _stream_tts_audio(
|
||||||
self,
|
self,
|
||||||
media_id: str,
|
tts_result: tts.ResultStream,
|
||||||
sample_rate: int = 16000,
|
sample_rate: int = 16000,
|
||||||
sample_width: int = 2,
|
sample_width: int = 2,
|
||||||
sample_channels: int = 1,
|
sample_channels: int = 1,
|
||||||
@ -579,15 +580,14 @@ class EsphomeAssistSatellite(
|
|||||||
if not self._is_running:
|
if not self._is_running:
|
||||||
return
|
return
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
if tts_result.extension != "wav":
|
||||||
self.hass,
|
_LOGGER.error(
|
||||||
media_id,
|
"Only WAV audio can be streamed, got %s", tts_result.extension
|
||||||
)
|
)
|
||||||
|
|
||||||
if extension != "wav":
|
|
||||||
_LOGGER.error("Only WAV audio can be streamed, got %s", extension)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
|
data = b"".join([chunk async for chunk in tts_result.async_stream_result()])
|
||||||
|
|
||||||
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
with io.BytesIO(data) as wav_io, wave.open(wav_io, "rb") as wav_file:
|
||||||
if (
|
if (
|
||||||
(wav_file.getframerate() != sample_rate)
|
(wav_file.getframerate() != sample_rate)
|
||||||
|
@ -58,6 +58,7 @@ from homeassistant.helpers import (
|
|||||||
intent as intent_helper,
|
intent as intent_helper,
|
||||||
)
|
)
|
||||||
from homeassistant.helpers.entity_component import EntityComponent
|
from homeassistant.helpers.entity_component import EntityComponent
|
||||||
|
from homeassistant.helpers.network import get_url
|
||||||
|
|
||||||
from .conftest import MockESPHomeDevice
|
from .conftest import MockESPHomeDevice
|
||||||
|
|
||||||
@ -133,8 +134,6 @@ async def test_pipeline_api_audio(
|
|||||||
) -> None:
|
) -> None:
|
||||||
"""Test a complete pipeline run with API audio (over the TCP connection)."""
|
"""Test a complete pipeline run with API audio (over the TCP connection)."""
|
||||||
conversation_id = "test-conversation-id"
|
conversation_id = "test-conversation-id"
|
||||||
media_url = "http://test.url"
|
|
||||||
media_id = "test-media-id"
|
|
||||||
|
|
||||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
mock_client=mock_client,
|
mock_client=mock_client,
|
||||||
@ -328,15 +327,22 @@ async def test_pipeline_api_audio(
|
|||||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
|
|
||||||
# Should return mock_wav audio
|
# Should return mock_wav audio
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||||
event_callback(
|
event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
data={
|
||||||
|
"tts_output": {
|
||||||
|
"media_id": "test-media-id",
|
||||||
|
"url": mock_tts_result_stream.url,
|
||||||
|
"token": mock_tts_result_stream.token,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
|
||||||
{"url": media_url},
|
{"url": get_url(hass) + mock_tts_result_stream.url},
|
||||||
)
|
)
|
||||||
|
|
||||||
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
|
||||||
@ -355,12 +361,6 @@ async def test_pipeline_api_audio(
|
|||||||
original_handle_pipeline_finished()
|
original_handle_pipeline_finished()
|
||||||
pipeline_finished.set()
|
pipeline_finished.set()
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
return ("wav", mock_wav)
|
|
||||||
|
|
||||||
tts_finished = asyncio.Event()
|
tts_finished = asyncio.Event()
|
||||||
original_tts_response_finished = satellite.tts_response_finished
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
@ -373,10 +373,6 @@ async def test_pipeline_api_audio(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||||
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
|
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
|
||||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
@ -434,8 +430,6 @@ async def test_pipeline_udp_audio(
|
|||||||
mainly focused on the UDP server.
|
mainly focused on the UDP server.
|
||||||
"""
|
"""
|
||||||
conversation_id = "test-conversation-id"
|
conversation_id = "test-conversation-id"
|
||||||
media_url = "http://test.url"
|
|
||||||
media_id = "test-media-id"
|
|
||||||
|
|
||||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
mock_client=mock_client,
|
mock_client=mock_client,
|
||||||
@ -522,10 +516,17 @@ async def test_pipeline_udp_audio(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should return mock_wav audio
|
# Should return mock_wav audio
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||||
event_callback(
|
event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
data={
|
||||||
|
"tts_output": {
|
||||||
|
"media_id": "test-media-id",
|
||||||
|
"url": mock_tts_result_stream.url,
|
||||||
|
"token": mock_tts_result_stream.token,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -538,12 +539,6 @@ async def test_pipeline_udp_audio(
|
|||||||
original_handle_pipeline_finished()
|
original_handle_pipeline_finished()
|
||||||
pipeline_finished.set()
|
pipeline_finished.set()
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
return ("wav", mock_wav)
|
|
||||||
|
|
||||||
tts_finished = asyncio.Event()
|
tts_finished = asyncio.Event()
|
||||||
original_tts_response_finished = satellite.tts_response_finished
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
@ -567,10 +562,6 @@ async def test_pipeline_udp_audio(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
):
|
):
|
||||||
@ -652,8 +643,6 @@ async def test_pipeline_media_player(
|
|||||||
mainly focused on tts_response_finished getting automatically called.
|
mainly focused on tts_response_finished getting automatically called.
|
||||||
"""
|
"""
|
||||||
conversation_id = "test-conversation-id"
|
conversation_id = "test-conversation-id"
|
||||||
media_url = "http://test.url"
|
|
||||||
media_id = "test-media-id"
|
|
||||||
|
|
||||||
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
mock_device: MockESPHomeDevice = await mock_esphome_device(
|
||||||
mock_client=mock_client,
|
mock_client=mock_client,
|
||||||
@ -733,10 +722,17 @@ async def test_pipeline_media_player(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Should return mock_wav audio
|
# Should return mock_wav audio
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
|
||||||
event_callback(
|
event_callback(
|
||||||
PipelineEvent(
|
PipelineEvent(
|
||||||
type=PipelineEventType.TTS_END,
|
type=PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"url": media_url, "media_id": media_id}},
|
data={
|
||||||
|
"tts_output": {
|
||||||
|
"media_id": "test-media-id",
|
||||||
|
"url": mock_tts_result_stream.url,
|
||||||
|
"token": mock_tts_result_stream.token,
|
||||||
|
}
|
||||||
|
},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -749,12 +745,6 @@ async def test_pipeline_media_player(
|
|||||||
original_handle_pipeline_finished()
|
original_handle_pipeline_finished()
|
||||||
pipeline_finished.set()
|
pipeline_finished.set()
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
return ("wav", mock_wav)
|
|
||||||
|
|
||||||
tts_finished = asyncio.Event()
|
tts_finished = asyncio.Event()
|
||||||
original_tts_response_finished = satellite.tts_response_finished
|
original_tts_response_finished = satellite.tts_response_finished
|
||||||
|
|
||||||
@ -767,10 +757,6 @@ async def test_pipeline_media_player(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
|
||||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
):
|
):
|
||||||
@ -944,80 +930,63 @@ async def test_streaming_tts_errors(
|
|||||||
|
|
||||||
# Should not stream if not running
|
# Should not stream if not running
|
||||||
satellite._is_running = False
|
satellite._is_running = False
|
||||||
await satellite._stream_tts_audio("test-media-id")
|
await satellite._stream_tts_audio(MockResultStream(hass, "wav", mock_wav))
|
||||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
satellite._is_running = True
|
satellite._is_running = True
|
||||||
|
|
||||||
# Should only stream WAV
|
# Should only stream WAV
|
||||||
async def get_mp3(
|
await satellite._stream_tts_audio(MockResultStream(hass, "mp3", b""))
|
||||||
hass: HomeAssistant,
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
return ("mp3", b"")
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_mp3
|
|
||||||
):
|
|
||||||
await satellite._stream_tts_audio("test-media-id")
|
|
||||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
|
||||||
|
|
||||||
# Needs to be the correct sample rate, etc.
|
# Needs to be the correct sample rate, etc.
|
||||||
async def get_bad_wav(
|
with io.BytesIO() as wav_io:
|
||||||
hass: HomeAssistant,
|
with wave.open(wav_io, "wb") as wav_file:
|
||||||
media_source_id: str,
|
wav_file.setframerate(48000)
|
||||||
) -> tuple[str, bytes]:
|
wav_file.setsampwidth(2)
|
||||||
with io.BytesIO() as wav_io:
|
wav_file.setnchannels(1)
|
||||||
with wave.open(wav_io, "wb") as wav_file:
|
wav_file.writeframes(b"test-wav")
|
||||||
wav_file.setframerate(48000)
|
|
||||||
wav_file.setsampwidth(2)
|
|
||||||
wav_file.setnchannels(1)
|
|
||||||
wav_file.writeframes(b"test-wav")
|
|
||||||
|
|
||||||
return ("wav", wav_io.getvalue())
|
mock_tts_result_stream = MockResultStream(hass, "wav", wav_io.getvalue())
|
||||||
|
|
||||||
with patch(
|
await satellite._stream_tts_audio(mock_tts_result_stream)
|
||||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_bad_wav
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
):
|
|
||||||
await satellite._stream_tts_audio("test-media-id")
|
|
||||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
|
||||||
|
|
||||||
# Check that TTS_STREAM_* events still get sent after cancel
|
# Check that TTS_STREAM_* events still get sent after cancel
|
||||||
media_fetched = asyncio.Event()
|
media_fetched = asyncio.Event()
|
||||||
|
|
||||||
async def get_slow_wav(
|
mock_tts_result_stream = MockResultStream(hass, "wav", b"")
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
async def async_stream_result_slowly():
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
media_fetched.set()
|
media_fetched.set()
|
||||||
await asyncio.sleep(1)
|
await asyncio.sleep(1)
|
||||||
return ("wav", mock_wav)
|
yield mock_wav
|
||||||
|
|
||||||
|
mock_tts_result_stream.async_stream_result = async_stream_result_slowly
|
||||||
|
|
||||||
mock_client.send_voice_assistant_event.reset_mock()
|
mock_client.send_voice_assistant_event.reset_mock()
|
||||||
with patch(
|
|
||||||
"homeassistant.components.tts.async_get_media_source_audio", new=get_slow_wav
|
|
||||||
):
|
|
||||||
task = asyncio.create_task(satellite._stream_tts_audio("test-media-id"))
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
# Wait for media to be fetched
|
|
||||||
await media_fetched.wait()
|
|
||||||
|
|
||||||
# Cancel task
|
task = asyncio.create_task(satellite._stream_tts_audio(mock_tts_result_stream))
|
||||||
task.cancel()
|
async with asyncio.timeout(1):
|
||||||
await task
|
# Wait for media to be fetched
|
||||||
|
await media_fetched.wait()
|
||||||
|
|
||||||
# No audio should have gone out
|
# Cancel task
|
||||||
mock_client.send_voice_assistant_audio.assert_not_called()
|
task.cancel()
|
||||||
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
await task
|
||||||
|
|
||||||
# The TTS_STREAM_* events should have gone out
|
# No audio should have gone out
|
||||||
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
mock_client.send_voice_assistant_audio.assert_not_called()
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
|
||||||
{},
|
|
||||||
)
|
# The TTS_STREAM_* events should have gone out
|
||||||
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
|
||||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
|
||||||
{},
|
{},
|
||||||
)
|
)
|
||||||
|
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
|
||||||
|
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
|
||||||
|
{},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_format_from_media_player(
|
async def test_tts_format_from_media_player(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user