mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Handle announcement finished for ESPHome TTS response (#125625)
* Handle announcement finished for TTS response
* Adjust test
This commit is contained in:
parent
970d28bce9
commit
3eed5de367
@ -14,6 +14,7 @@ import wave
|
|||||||
|
|
||||||
from aioesphomeapi import (
|
from aioesphomeapi import (
|
||||||
MediaPlayerFormatPurpose,
|
MediaPlayerFormatPurpose,
|
||||||
|
VoiceAssistantAnnounceFinished,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantCommandFlag,
|
VoiceAssistantCommandFlag,
|
||||||
VoiceAssistantEventType,
|
VoiceAssistantEventType,
|
||||||
@ -166,6 +167,7 @@ class EsphomeAssistSatellite(
|
|||||||
handle_start=self.handle_pipeline_start,
|
handle_start=self.handle_pipeline_start,
|
||||||
handle_stop=self.handle_pipeline_stop,
|
handle_stop=self.handle_pipeline_stop,
|
||||||
handle_audio=self.handle_audio,
|
handle_audio=self.handle_audio,
|
||||||
|
handle_announcement_finished=self.handle_announcement_finished,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -174,6 +176,7 @@ class EsphomeAssistSatellite(
|
|||||||
self.cli.subscribe_voice_assistant(
|
self.cli.subscribe_voice_assistant(
|
||||||
handle_start=self.handle_pipeline_start,
|
handle_start=self.handle_pipeline_start,
|
||||||
handle_stop=self.handle_pipeline_stop,
|
handle_stop=self.handle_pipeline_stop,
|
||||||
|
handle_announcement_finished=self.handle_announcement_finished,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -194,6 +197,10 @@ class EsphomeAssistSatellite(
|
|||||||
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if not (feature_flags & VoiceAssistantFeature.SPEAKER):
|
||||||
|
# Will use media player for TTS/announcements
|
||||||
|
self._update_tts_format()
|
||||||
|
|
||||||
async def async_will_remove_from_hass(self) -> None:
|
async def async_will_remove_from_hass(self) -> None:
|
||||||
"""Run when entity will be removed from hass."""
|
"""Run when entity will be removed from hass."""
|
||||||
await super().async_will_remove_from_hass()
|
await super().async_will_remove_from_hass()
|
||||||
@ -382,6 +389,12 @@ class EsphomeAssistSatellite(
|
|||||||
timer_info.is_active,
|
timer_info.is_active,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def handle_announcement_finished(
    self, announce_finished: VoiceAssistantAnnounceFinished
) -> None:
    """Handle announcement finished message (also sent for TTS).

    Registered as the ``handle_announcement_finished`` callback when
    subscribing to the voice assistant. The payload itself is not inspected
    here; receiving the message is what matters.

    Args:
        announce_finished: The "announce finished" message passed to the
            callback. Currently unused.
    """
    # Signal that the TTS response has finished.
    self.tts_response_finished()
|
||||||
|
|
||||||
def _update_tts_format(self) -> None:
|
def _update_tts_format(self) -> None:
|
||||||
"""Update the TTS format from the first media player."""
|
"""Update the TTS format from the first media player."""
|
||||||
for supported_format in chain(*self.entry_data.media_player_formats.values()):
|
for supported_format in chain(*self.entry_data.media_player_formats.values()):
|
||||||
|
@ -19,6 +19,7 @@ from aioesphomeapi import (
|
|||||||
HomeassistantServiceCall,
|
HomeassistantServiceCall,
|
||||||
ReconnectLogic,
|
ReconnectLogic,
|
||||||
UserService,
|
UserService,
|
||||||
|
VoiceAssistantAnnounceFinished,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantFeature,
|
VoiceAssistantFeature,
|
||||||
)
|
)
|
||||||
@ -214,6 +215,13 @@ class MockESPHomeDevice:
|
|||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
)
|
)
|
||||||
|
self.voice_assistant_handle_announcement_finished_callback: (
|
||||||
|
Callable[
|
||||||
|
[VoiceAssistantAnnounceFinished],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
)
|
||||||
self.device_info = device_info
|
self.device_info = device_info
|
||||||
|
|
||||||
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
|
def set_state_callback(self, state_callback: Callable[[EntityState], None]) -> None:
|
||||||
@ -295,11 +303,21 @@ class MockESPHomeDevice:
|
|||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
) = None,
|
) = None,
|
||||||
|
handle_announcement_finished: (
|
||||||
|
Callable[
|
||||||
|
[VoiceAssistantAnnounceFinished],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
) = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set the voice assistant subscription callbacks."""
|
"""Set the voice assistant subscription callbacks."""
|
||||||
self.voice_assistant_handle_start_callback = handle_start
|
self.voice_assistant_handle_start_callback = handle_start
|
||||||
self.voice_assistant_handle_stop_callback = handle_stop
|
self.voice_assistant_handle_stop_callback = handle_stop
|
||||||
self.voice_assistant_handle_audio_callback = handle_audio
|
self.voice_assistant_handle_audio_callback = handle_audio
|
||||||
|
self.voice_assistant_handle_announcement_finished_callback = (
|
||||||
|
handle_announcement_finished
|
||||||
|
)
|
||||||
|
|
||||||
async def mock_voice_assistant_handle_start(
|
async def mock_voice_assistant_handle_start(
|
||||||
self,
|
self,
|
||||||
@ -322,6 +340,13 @@ class MockESPHomeDevice:
|
|||||||
assert self.voice_assistant_handle_audio_callback is not None
|
assert self.voice_assistant_handle_audio_callback is not None
|
||||||
await self.voice_assistant_handle_audio_callback(audio)
|
await self.voice_assistant_handle_audio_callback(audio)
|
||||||
|
|
||||||
|
async def mock_voice_assistant_handle_announcement_finished(
    self, finished: VoiceAssistantAnnounceFinished
) -> None:
    """Mock voice assistant handle announcement finished.

    Forwards ``finished`` to the announcement-finished callback that was
    stored by ``set_subscribe_voice_assistant_callbacks``.

    Args:
        finished: The message to hand to the registered callback.

    Raises:
        AssertionError: If no announcement-finished callback is registered.
    """
    # Test helper: fail loudly if the integration never subscribed.
    assert self.voice_assistant_handle_announcement_finished_callback is not None
    await self.voice_assistant_handle_announcement_finished_callback(finished)
|
||||||
|
|
||||||
|
|
||||||
async def _mock_generic_device_entry(
|
async def _mock_generic_device_entry(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
@ -402,10 +427,17 @@ async def _mock_generic_device_entry(
|
|||||||
]
|
]
|
||||||
| None
|
| None
|
||||||
) = None,
|
) = None,
|
||||||
|
handle_announcement_finished: (
|
||||||
|
Callable[
|
||||||
|
[VoiceAssistantAnnounceFinished],
|
||||||
|
Coroutine[Any, Any, None],
|
||||||
|
]
|
||||||
|
| None
|
||||||
|
) = None,
|
||||||
) -> Callable[[], None]:
|
) -> Callable[[], None]:
|
||||||
"""Subscribe to voice assistant."""
|
"""Subscribe to voice assistant."""
|
||||||
mock_device.set_subscribe_voice_assistant_callbacks(
|
mock_device.set_subscribe_voice_assistant_callbacks(
|
||||||
handle_start, handle_stop, handle_audio
|
handle_start, handle_stop, handle_audio, handle_announcement_finished
|
||||||
)
|
)
|
||||||
|
|
||||||
def unsub():
|
def unsub():
|
||||||
|
@ -15,6 +15,7 @@ from aioesphomeapi import (
|
|||||||
MediaPlayerInfo,
|
MediaPlayerInfo,
|
||||||
MediaPlayerSupportedFormat,
|
MediaPlayerSupportedFormat,
|
||||||
UserService,
|
UserService,
|
||||||
|
VoiceAssistantAnnounceFinished,
|
||||||
VoiceAssistantAudioSettings,
|
VoiceAssistantAudioSettings,
|
||||||
VoiceAssistantCommandFlag,
|
VoiceAssistantCommandFlag,
|
||||||
VoiceAssistantEventType,
|
VoiceAssistantEventType,
|
||||||
@ -603,6 +604,160 @@ async def test_udp_errors() -> None:
|
|||||||
protocol.transport.sendto.assert_not_called()
|
protocol.transport.sendto.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_media_player(
    hass: HomeAssistant,
    mock_client: APIClient,
    mock_esphome_device: Callable[
        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
        Awaitable[MockESPHomeDevice],
    ],
    mock_wav: bytes,
) -> None:
    """Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.

    This test is not as comprehensive as test_pipeline_api_audio since we're
    mainly focused on tts_response_finished getting automatically called.
    """
    conversation_id = "test-conversation-id"
    media_url = "http://test.url"
    media_id = "test-media-id"

    # The SPEAKER feature flag is deliberately not set on this device.
    mock_device: MockESPHomeDevice = await mock_esphome_device(
        mock_client=mock_client,
        entity_info=[],
        user_service=[],
        states=[],
        device_info={
            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
            | VoiceAssistantFeature.API_AUDIO
        },
    )
    await hass.async_block_till_done()

    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
    assert satellite is not None

    async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
        """Stand-in pipeline that emits a fixed sequence of events."""
        stt_stream = kwargs["stt_stream"]

        # Pull at most one chunk from the STT stream, then stop reading.
        async for _chunk in stt_stream:
            break

        event_callback = kwargs["event_callback"]

        # STT
        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_START,
                data={"engine": "test-stt-engine", "metadata": {}},
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.STT_END,
                data={"stt_output": {"text": "test-stt-text"}},
            )
        )

        # Intent
        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_START,
                data={
                    "engine": "test-intent-engine",
                    "language": hass.config.language,
                    "intent_input": "test-intent-text",
                    "conversation_id": conversation_id,
                    "device_id": device_id,
                },
            )
        )

        event_callback(
            PipelineEvent(
                type=PipelineEventType.INTENT_END,
                data={"intent_output": {"conversation_id": conversation_id}},
            )
        )

        # TTS
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_START,
                data={
                    "engine": "test-stt-engine",
                    "language": hass.config.language,
                    "voice": "test-voice",
                    "tts_input": "test-tts-text",
                },
            )
        )

        # Should return mock_wav audio
        event_callback(
            PipelineEvent(
                type=PipelineEventType.TTS_END,
                data={"tts_output": {"url": media_url, "media_id": media_id}},
            )
        )

        event_callback(PipelineEvent(type=PipelineEventType.RUN_END))

    # Wrap handle_pipeline_finished so the test can wait for it.
    pipeline_finished = asyncio.Event()
    original_handle_pipeline_finished = satellite.handle_pipeline_finished

    def handle_pipeline_finished():
        original_handle_pipeline_finished()
        pipeline_finished.set()

    async def async_get_media_source_audio(
        hass: HomeAssistant,
        media_source_id: str,
    ) -> tuple[str, bytes]:
        """Return the fixture WAV for any media source id."""
        return ("wav", mock_wav)

    # Wrap tts_response_finished so the test can wait for it.
    tts_finished = asyncio.Event()
    original_tts_response_finished = satellite.tts_response_finished

    def tts_response_finished():
        original_tts_response_finished()
        tts_finished.set()

    with (
        patch(
            "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
            new=async_pipeline_from_audio_stream,
        ),
        patch(
            "homeassistant.components.tts.async_get_media_source_audio",
            new=async_get_media_source_audio,
        ),
        patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
        patch.object(satellite, "tts_response_finished", tts_response_finished),
    ):
        # Bound the whole exchange so a missing callback fails fast
        # instead of hanging the test run.
        async with asyncio.timeout(1):
            await satellite.handle_pipeline_start(
                conversation_id=conversation_id,
                flags=VoiceAssistantCommandFlag(0),  # stt
                audio_settings=VoiceAssistantAudioSettings(),
                wake_word_phrase="",
            )

            await satellite.handle_pipeline_stop(abort=False)
            await pipeline_finished.wait()

            # Still responding until the device reports the announcement finished.
            assert satellite.state == AssistSatelliteState.RESPONDING

            # Will trigger tts_response_finished
            await mock_device.mock_voice_assistant_handle_announcement_finished(
                VoiceAssistantAnnounceFinished(success=True)
            )
            await tts_finished.wait()

            assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
|
||||||
async def test_timer_events(
|
async def test_timer_events(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
device_registry: dr.DeviceRegistry,
|
device_registry: dr.DeviceRegistry,
|
||||||
@ -952,6 +1107,7 @@ async def test_announce_message(
|
|||||||
async def send_voice_assistant_announcement_await_response(
|
async def send_voice_assistant_announcement_await_response(
|
||||||
media_id: str, timeout: float, text: str
|
media_id: str, timeout: float, text: str
|
||||||
):
|
):
|
||||||
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||||
assert text == "test-text"
|
assert text == "test-text"
|
||||||
|
|
||||||
@ -983,6 +1139,7 @@ async def test_announce_message(
|
|||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
|
||||||
async def test_announce_media_id(
|
async def test_announce_media_id(
|
||||||
@ -1016,6 +1173,7 @@ async def test_announce_media_id(
|
|||||||
async def send_voice_assistant_announcement_await_response(
|
async def send_voice_assistant_announcement_await_response(
|
||||||
media_id: str, timeout: float, text: str
|
media_id: str, timeout: float, text: str
|
||||||
):
|
):
|
||||||
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
assert media_id == "https://www.home-assistant.io/resolved.mp3"
|
||||||
|
|
||||||
done.set()
|
done.set()
|
||||||
@ -1038,6 +1196,7 @@ async def test_announce_media_id(
|
|||||||
blocking=True,
|
blocking=True,
|
||||||
)
|
)
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
assert satellite.state == AssistSatelliteState.LISTENING_WAKE_WORD
|
||||||
|
|
||||||
|
|
||||||
async def test_satellite_unloaded_on_disconnect(
|
async def test_satellite_unloaded_on_disconnect(
|
||||||
|
Loading…
x
Reference in New Issue
Block a user