mirror of
https://github.com/home-assistant/core.git
synced 2025-07-26 06:37:52 +00:00
Migrate VoIP to use Assist Pipeline TTS tokens (#139671)
* Migrate VoIP to use pipeline token
* Migrate announcements to use TTS token
This commit is contained in:
parent
871a7c87bf
commit
8aa30b0ccb
@ -408,10 +408,18 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
"""Play an announcement once."""
|
"""Play an announcement once."""
|
||||||
_LOGGER.debug("Playing announcement")
|
_LOGGER.debug("Playing announcement")
|
||||||
|
|
||||||
try:
|
if announcement.tts_token is None:
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
_LOGGER.error("Only TTS announcements are supported")
|
||||||
await self._send_tts(announcement.original_media_id, wait_for_tone=False)
|
return
|
||||||
|
|
||||||
|
await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY)
|
||||||
|
stream = tts.async_get_stream(self.hass, announcement.tts_token)
|
||||||
|
if stream is None:
|
||||||
|
_LOGGER.error("TTS stream no longer available")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._send_tts(stream, wait_for_tone=False)
|
||||||
if not self._run_pipeline_after_announce:
|
if not self._run_pipeline_after_announce:
|
||||||
# Delay before looping announcement
|
# Delay before looping announcement
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY)
|
||||||
@ -442,11 +450,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
)
|
)
|
||||||
elif event.type == PipelineEventType.TTS_END:
|
elif event.type == PipelineEventType.TTS_END:
|
||||||
# Send TTS audio to caller over RTP
|
# Send TTS audio to caller over RTP
|
||||||
if event.data and (tts_output := event.data["tts_output"]):
|
if (
|
||||||
media_id = tts_output["media_id"]
|
event.data
|
||||||
|
and (tts_output := event.data["tts_output"])
|
||||||
|
and (stream := tts.async_get_stream(self.hass, tts_output["token"]))
|
||||||
|
):
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._send_tts(media_id),
|
self._send_tts(tts_stream=stream),
|
||||||
"voip_pipeline_tts",
|
"voip_pipeline_tts",
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@ -457,19 +468,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
self._pipeline_had_error = True
|
self._pipeline_had_error = True
|
||||||
_LOGGER.warning(event)
|
_LOGGER.warning(event)
|
||||||
|
|
||||||
async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None:
|
async def _send_tts(
|
||||||
|
self,
|
||||||
|
tts_stream: tts.ResultStream,
|
||||||
|
wait_for_tone: bool = True,
|
||||||
|
) -> None:
|
||||||
"""Send TTS audio to caller via RTP."""
|
"""Send TTS audio to caller via RTP."""
|
||||||
try:
|
try:
|
||||||
if self.transport is None:
|
if self.transport is None:
|
||||||
return # not connected
|
return # not connected
|
||||||
|
|
||||||
extension, data = await tts.async_get_media_source_audio(
|
data = b"".join([chunk async for chunk in tts_stream.async_stream_result()])
|
||||||
self.hass,
|
|
||||||
media_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if extension != "wav":
|
if tts_stream.extension != "wav":
|
||||||
raise ValueError(f"Only WAV audio can be streamed, got {extension}")
|
raise ValueError(
|
||||||
|
f"Only TTS WAV audio can be streamed, got {tts_stream.extension}"
|
||||||
|
)
|
||||||
|
|
||||||
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
|
if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING):
|
||||||
# Don't overlap TTS and processing beep
|
# Don't overlap TTS and processing beep
|
||||||
|
@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None:
|
|||||||
"""Mock the TTS cache dir with empty dir."""
|
"""Mock the TTS cache dir with empty dir."""
|
||||||
|
|
||||||
|
|
||||||
def _empty_wav() -> bytes:
|
def _empty_wav(framerate=16000) -> bytes:
|
||||||
"""Return bytes of an empty WAV file."""
|
"""Return bytes of an empty WAV file."""
|
||||||
with io.BytesIO() as wav_io:
|
with io.BytesIO() as wav_io:
|
||||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
||||||
with wav_file:
|
with wav_file:
|
||||||
wav_file.setframerate(16000)
|
wav_file.setframerate(framerate)
|
||||||
wav_file.setsampwidth(2)
|
wav_file.setsampwidth(2)
|
||||||
wav_file.setnchannels(1)
|
wav_file.setnchannels(1)
|
||||||
|
|
||||||
@ -307,10 +307,11 @@ async def test_pipeline(
|
|||||||
assert satellite.state == AssistSatelliteState.RESPONDING
|
assert satellite.state == AssistSatelliteState.RESPONDING
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -326,22 +327,11 @@ async def test_pipeline(
|
|||||||
original_tts_response_finished()
|
original_tts_response_finished()
|
||||||
done.set()
|
done.set()
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
media_source_id: str,
|
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
assert media_source_id == _MEDIA_ID
|
|
||||||
return ("wav", _empty_wav())
|
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
):
|
):
|
||||||
satellite._tones = Tones(0)
|
satellite._tones = Tones(0)
|
||||||
@ -457,10 +447,11 @@ async def test_tts_timeout(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -474,22 +465,9 @@ async def test_tts_timeout(
|
|||||||
# Block here to force a timeout in _send_tts
|
# Block here to force a timeout in _send_tts
|
||||||
await asyncio.sleep(2)
|
await asyncio.sleep(2)
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
with patch(
|
||||||
hass: HomeAssistant,
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
media_source_id: str,
|
new=async_pipeline_from_audio_stream,
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
# Should time out immediately
|
|
||||||
return ("wav", _empty_wav())
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
satellite._tts_extra_timeout = 0.001
|
satellite._tts_extra_timeout = 0.001
|
||||||
for tone in Tones:
|
for tone in Tones:
|
||||||
@ -568,29 +546,18 @@ async def test_tts_wrong_extension(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
|
# Should fail because it's not "wav"
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "mp3", b"")
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
with patch(
|
||||||
hass: HomeAssistant,
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
media_source_id: str,
|
new=async_pipeline_from_audio_stream,
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
# Should fail because it's not "wav"
|
|
||||||
return ("mp3", b"")
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.transport = Mock()
|
||||||
|
|
||||||
@ -663,36 +630,18 @@ async def test_tts_wrong_wav_format(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
|
# Should fail because it's not 16Khz
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050))
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def async_get_media_source_audio(
|
with patch(
|
||||||
hass: HomeAssistant,
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
media_source_id: str,
|
new=async_pipeline_from_audio_stream,
|
||||||
) -> tuple[str, bytes]:
|
|
||||||
# Should fail because it's not 16Khz, 16-bit mono
|
|
||||||
with io.BytesIO() as wav_io:
|
|
||||||
wav_file: wave.Wave_write = wave.open(wav_io, "wb")
|
|
||||||
with wav_file:
|
|
||||||
wav_file.setframerate(22050)
|
|
||||||
wav_file.setsampwidth(2)
|
|
||||||
wav_file.setnchannels(2)
|
|
||||||
|
|
||||||
return ("wav", wav_io.getvalue())
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio",
|
|
||||||
new=async_get_media_source_audio,
|
|
||||||
),
|
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.transport = Mock()
|
||||||
|
|
||||||
@ -878,10 +827,11 @@ async def test_announce(
|
|||||||
assert err.value.translation_domain == "voip"
|
assert err.value.translation_domain == "voip"
|
||||||
assert err.value.translation_key == "non_tts_announcement"
|
assert err.value.translation_key == "non_tts_announcement"
|
||||||
|
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token="test-token",
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -907,7 +857,9 @@ async def test_announce(
|
|||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
mock_send_tts.assert_called_once_with(
|
||||||
|
mock_tts_result_stream, wait_for_tone=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("socket_enabled")
|
@pytest.mark.usefixtures("socket_enabled")
|
||||||
@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address(
|
|||||||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token="test-token",
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address(
|
|||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(1):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False)
|
mock_send_tts.assert_called_once_with(
|
||||||
|
mock_tts_result_stream, wait_for_tone=False
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.usefixtures("socket_enabled")
|
@pytest.mark.usefixtures("socket_enabled")
|
||||||
@ -979,10 +934,11 @@ async def test_announce_timeout(
|
|||||||
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
& assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token="test-token",
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -1020,10 +976,11 @@ async def test_start_conversation(
|
|||||||
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
& assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION
|
||||||
)
|
)
|
||||||
|
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
message="test announcement",
|
message="test announcement",
|
||||||
media_id=_MEDIA_ID,
|
media_id=_MEDIA_ID,
|
||||||
tts_token="test-token",
|
tts_token=mock_tts_result_stream.token,
|
||||||
original_media_id=_MEDIA_ID,
|
original_media_id=_MEDIA_ID,
|
||||||
media_id_source="tts",
|
media_id_source="tts",
|
||||||
)
|
)
|
||||||
@ -1061,10 +1018,11 @@ async def test_start_conversation(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Proceed with media output
|
# Proceed with media output
|
||||||
|
mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav())
|
||||||
event_callback(
|
event_callback(
|
||||||
assist_pipeline.PipelineEvent(
|
assist_pipeline.PipelineEvent(
|
||||||
type=assist_pipeline.PipelineEventType.TTS_END,
|
type=assist_pipeline.PipelineEventType.TTS_END,
|
||||||
data={"tts_output": {"media_id": _MEDIA_ID}},
|
data={"tts_output": {"token": mock_tts_result_stream.token}},
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user