Improve Voip pipeline stability (#137620)

* Improve Voip pipeline stability

It appears the pipeline is being unexpectedly cancelled in some
instances. In order to mitigate this issue hang ups will be detected
using a separate task rather than relying on timeouts in the STT read
method. Also reading STT events will be retried once if it is cancelled.
The pipeline will also catch and log any CancelledErrors to help with
further debugging.

* Update Voip tests

* Remove unnecessary changes

Remove unnecessary logging and cancelled error handling in wyoming STT.

* Remove comment about clearing system prompt

The test no longer checks for clearing the system prompt. Since that
logic exists completely in the assist_satellite component I think it is
reasonable to only test that logic in the unit tests for that component.

* Re-raise cancellation

Re-raise CancelledError if the current task is cancelling in the check hangup task

Co-authored-by: J. Nick Koston <nick@koston.org>

* Re-raise CancelledError in pipeline as well

* Fix formatting issue

* Remove unnecessary logging

* Add MockResultStream import to tests

This was presumably missed while merging

* Cancel check hangup task on disconnect

* Add myself as codeowner for VoIP

* Update CODEOWNERS

---------

Co-authored-by: J. Nick Koston <nick@koston.org>
Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
Jamin 2025-05-05 19:25:52 -05:00 committed by GitHub
parent f3b23afc92
commit 14f967cdd0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 171 additions and 125 deletions

4
CODEOWNERS generated
View File

@ -1678,8 +1678,8 @@ build.json @home-assistant/supervisor
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare /tests/components/vlc_telnet/ @rodripf @MartinHjelmare
/homeassistant/components/vodafone_station/ @paoloantinori @chemelli74 /homeassistant/components/vodafone_station/ @paoloantinori @chemelli74
/tests/components/vodafone_station/ @paoloantinori @chemelli74 /tests/components/vodafone_station/ @paoloantinori @chemelli74
/homeassistant/components/voip/ @balloob @synesthesiam /homeassistant/components/voip/ @balloob @synesthesiam @jaminh
/tests/components/voip/ @balloob @synesthesiam /tests/components/voip/ @balloob @synesthesiam @jaminh
/homeassistant/components/volumio/ @OnFreund /homeassistant/components/volumio/ @OnFreund
/tests/components/volumio/ @OnFreund /tests/components/volumio/ @OnFreund
/homeassistant/components/volvooncall/ @molobrakos /homeassistant/components/volvooncall/ @molobrakos

View File

@ -51,9 +51,9 @@ if TYPE_CHECKING:
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
_PIPELINE_TIMEOUT_SEC: Final = 30 _PIPELINE_TIMEOUT_SEC: Final = 30
_HANGUP_SEC: Final = 0.5
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5 _ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0 _ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5
_ANNOUNCEMENT_RING_TIMEOUT: Final = 30 _ANNOUNCEMENT_RING_TIMEOUT: Final = 30
@ -132,9 +132,10 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self._processing_tone_done = asyncio.Event() self._processing_tone_done = asyncio.Event()
self._announcement: AssistSatelliteAnnouncement | None = None self._announcement: AssistSatelliteAnnouncement | None = None
self._announcement_future: asyncio.Future[Any] = asyncio.Future()
self._announcment_start_time: float = 0.0 self._announcment_start_time: float = 0.0
self._check_announcement_ended_task: asyncio.Task | None = None self._check_announcement_pickup_task: asyncio.Task | None = None
self._check_hangup_task: asyncio.Task | None = None
self._call_end_future: asyncio.Future[Any] = asyncio.Future()
self._last_chunk_time: float | None = None self._last_chunk_time: float | None = None
self._rtp_port: int | None = None self._rtp_port: int | None = None
self._run_pipeline_after_announce: bool = False self._run_pipeline_after_announce: bool = False
@ -233,7 +234,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
translation_key="non_tts_announcement", translation_key="non_tts_announcement",
) )
self._announcement_future = asyncio.Future() self._call_end_future = asyncio.Future()
self._run_pipeline_after_announce = run_pipeline_after self._run_pipeline_after_announce = run_pipeline_after
if self._rtp_port is None: if self._rtp_port is None:
@ -274,53 +275,77 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
rtp_port=self._rtp_port, rtp_port=self._rtp_port,
) )
# Check if caller hung up or didn't pick up # Check if caller didn't pick up
self._check_announcement_ended_task = ( self._check_announcement_pickup_task = (
self.config_entry.async_create_background_task( self.config_entry.async_create_background_task(
self.hass, self.hass,
self._check_announcement_ended(), self._check_announcement_pickup(),
"voip_announcement_ended", "voip_announcement_pickup",
) )
) )
try: try:
await self._announcement_future await self._call_end_future
except TimeoutError: except TimeoutError:
# Stop ringing # Stop ringing
_LOGGER.debug("Caller did not pick up in time")
sip_protocol.cancel_call(call_info) sip_protocol.cancel_call(call_info)
raise raise
async def _check_announcement_ended(self) -> None: async def _check_announcement_pickup(self) -> None:
"""Continuously checks if an audio chunk was received within a time limit. """Continuously checks if an audio chunk was received within a time limit.
If not, the caller is presumed to have hung up and the announcement is ended. If not, the caller is presumed to have not picked up the phone and the announcement is ended.
""" """
while self._announcement is not None: while True:
current_time = time.monotonic() current_time = time.monotonic()
if (self._last_chunk_time is None) and ( if (self._last_chunk_time is None) and (
(current_time - self._announcment_start_time) (current_time - self._announcment_start_time)
> _ANNOUNCEMENT_RING_TIMEOUT > _ANNOUNCEMENT_RING_TIMEOUT
): ):
# Ring timeout # Ring timeout
_LOGGER.debug("Ring timeout")
self._announcement = None self._announcement = None
self._check_announcement_ended_task = None self._check_announcement_pickup_task = None
self._announcement_future.set_exception( self._call_end_future.set_exception(
TimeoutError("User did not pick up in time") TimeoutError("User did not pick up in time")
) )
_LOGGER.debug("Timed out waiting for the user to pick up the phone") _LOGGER.debug("Timed out waiting for the user to pick up the phone")
break break
if self._last_chunk_time is not None:
if (self._last_chunk_time is not None) and ( _LOGGER.debug("Picked up the phone")
(current_time - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC self._check_announcement_pickup_task = None
):
# Caller hung up
self._announcement = None
self._announcement_future.set_result(None)
self._check_announcement_ended_task = None
_LOGGER.debug("Announcement ended")
break break
await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2) await asyncio.sleep(_HANGUP_SEC / 2)
async def _check_hangup(self) -> None:
"""Continuously checks if an audio chunk was received within a time limit.
If not, the caller is presumed to have hung up and the call is ended.
"""
try:
while True:
current_time = time.monotonic()
if (self._last_chunk_time is not None) and (
(current_time - self._last_chunk_time) > _HANGUP_SEC
):
# Caller hung up
_LOGGER.debug("Hang up")
self._announcement = None
if self._run_pipeline_task is not None:
_LOGGER.debug("Cancelling running pipeline")
self._run_pipeline_task.cancel()
self._call_end_future.set_result(None)
self.disconnect()
break
await asyncio.sleep(_HANGUP_SEC / 2)
except asyncio.CancelledError:
# Don't swallow cancellation
if (current_task := asyncio.current_task()) and current_task.cancelling():
raise
_LOGGER.debug("Check hangup cancelled")
async def async_start_conversation( async def async_start_conversation(
self, start_announcement: AssistSatelliteAnnouncement self, start_announcement: AssistSatelliteAnnouncement
@ -332,6 +357,24 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
# VoIP # VoIP
# ------------------------------------------------------------------------- # -------------------------------------------------------------------------
def disconnect(self):
"""Server disconnected."""
super().disconnect()
if self._check_hangup_task is not None:
self._check_hangup_task.cancel()
self._check_hangup_task = None
def connection_made(self, transport):
"""Server is ready."""
super().connection_made(transport)
self._last_chunk_time = time.monotonic()
# Check if caller hung up
self._check_hangup_task = self.config_entry.async_create_background_task(
self.hass,
self._check_hangup(),
"voip_hangup",
)
def on_chunk(self, audio_bytes: bytes) -> None: def on_chunk(self, audio_bytes: bytes) -> None:
"""Handle raw audio chunk.""" """Handle raw audio chunk."""
self._last_chunk_time = time.monotonic() self._last_chunk_time = time.monotonic()
@ -368,13 +411,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
self.voip_device.set_is_active(True) self.voip_device.set_is_active(True)
async def stt_stream(): async def stt_stream():
retry: bool = True
while True: while True:
async with asyncio.timeout(self._audio_chunk_timeout): try:
chunk = await self._audio_queue.get() async with asyncio.timeout(self._audio_chunk_timeout):
if not chunk: chunk = await self._audio_queue.get()
break if not chunk:
_LOGGER.debug("STT stream got None")
break
yield chunk yield chunk
except TimeoutError:
_LOGGER.debug("STT Stream timed out")
if not retry:
_LOGGER.debug("No more retries, ending STT stream")
break
retry = False
# Play listening tone at the start of each cycle # Play listening tone at the start of each cycle
await self._play_tone(Tones.LISTENING, silence_before=0.2) await self._play_tone(Tones.LISTENING, silence_before=0.2)
@ -385,6 +437,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
) )
if self._pipeline_had_error: if self._pipeline_had_error:
_LOGGER.debug("Pipeline error")
self._pipeline_had_error = False self._pipeline_had_error = False
await self._play_tone(Tones.ERROR) await self._play_tone(Tones.ERROR)
else: else:
@ -394,7 +447,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
# length of the TTS audio. # length of the TTS audio.
await self._tts_done.wait() await self._tts_done.wait()
except TimeoutError: except TimeoutError:
# This shouldn't happen anymore, we are detecting hang ups with a separate task
_LOGGER.exception("Timeout error")
self.disconnect() # caller hung up self.disconnect() # caller hung up
except asyncio.CancelledError:
_LOGGER.debug("Pipeline cancelled")
# Don't swallow cancellation
if (current_task := asyncio.current_task()) and current_task.cancelling():
raise
finally: finally:
# Stop audio stream # Stop audio stream
await self._audio_queue.put(None) await self._audio_queue.put(None)
@ -433,8 +493,8 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
if self._run_pipeline_after_announce: if self._run_pipeline_after_announce:
# Clear announcement to allow pipeline to run # Clear announcement to allow pipeline to run
_LOGGER.debug("Clearing announcement")
self._announcement = None self._announcement = None
self._announcement_future.set_result(None)
def _clear_audio_queue(self) -> None: def _clear_audio_queue(self) -> None:
"""Ensure audio queue is empty.""" """Ensure audio queue is empty."""
@ -463,6 +523,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
) )
else: else:
# Empty TTS response # Empty TTS response
_LOGGER.debug("Empty TTS response")
self._tts_done.set() self._tts_done.set()
elif event.type == PipelineEventType.ERROR: elif event.type == PipelineEventType.ERROR:
# Play error tone instead of wait for TTS when pipeline is finished. # Play error tone instead of wait for TTS when pipeline is finished.

View File

@ -1,7 +1,7 @@
{ {
"domain": "voip", "domain": "voip",
"name": "Voice over IP", "name": "Voice over IP",
"codeowners": ["@balloob", "@synesthesiam"], "codeowners": ["@balloob", "@synesthesiam", "@jaminh"],
"config_flow": true, "config_flow": true,
"dependencies": ["assist_pipeline", "assist_satellite", "intent", "network"], "dependencies": ["assist_pipeline", "assist_satellite", "intent", "network"],
"documentation": "https://www.home-assistant.io/integrations/voip", "documentation": "https://www.home-assistant.io/integrations/voip",

View File

@ -335,9 +335,8 @@ async def test_pipeline(
patch.object(satellite, "tts_response_finished", tts_response_finished), patch.object(satellite, "tts_response_finished", tts_response_finished),
): ):
satellite._tones = Tones(0) satellite._tones = Tones(0)
satellite.transport = Mock() satellite.connection_made(Mock())
satellite.connection_made(satellite.transport)
assert satellite.state == AssistSatelliteState.IDLE assert satellite.state == AssistSatelliteState.IDLE
# Ensure audio queue is cleared before pipeline starts # Ensure audio queue is cleared before pipeline starts
@ -473,7 +472,7 @@ async def test_tts_timeout(
for tone in Tones: for tone in Tones:
satellite._tone_bytes[tone] = tone_bytes satellite._tone_bytes[tone] = tone_bytes
satellite.transport = Mock() satellite.connection_made(Mock())
satellite.send_audio = Mock() satellite.send_audio = Mock()
original_send_tts = satellite._send_tts original_send_tts = satellite._send_tts
@ -511,6 +510,7 @@ async def test_tts_wrong_extension(
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id) satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
satellite.addr = ("192.168.1.1", 12345)
assert isinstance(satellite, VoipAssistSatellite) assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event() done = asyncio.Event()
@ -559,8 +559,6 @@ async def test_tts_wrong_extension(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
): ):
satellite.transport = Mock()
original_send_tts = satellite._send_tts original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs): async def send_tts(*args, **kwargs):
@ -572,6 +570,8 @@ async def test_tts_wrong_extension(
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign] satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite.connection_made(Mock())
# silence # silence
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
@ -579,10 +579,18 @@ async def test_tts_wrong_extension(
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2)) satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity) # silence (assumes relaxed VAD sensitivity)
satellite.on_chunk(bytes(_ONE_SECOND * 4)) satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream # Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1): async with asyncio.timeout(3):
await done.wait() await done.wait()
@ -595,6 +603,7 @@ async def test_tts_wrong_wav_format(
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id) satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
satellite.addr = ("192.168.1.1", 12345)
assert isinstance(satellite, VoipAssistSatellite) assert isinstance(satellite, VoipAssistSatellite)
done = asyncio.Event() done = asyncio.Event()
@ -643,8 +652,6 @@ async def test_tts_wrong_wav_format(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
): ):
satellite.transport = Mock()
original_send_tts = satellite._send_tts original_send_tts = satellite._send_tts
async def send_tts(*args, **kwargs): async def send_tts(*args, **kwargs):
@ -656,6 +663,8 @@ async def test_tts_wrong_wav_format(
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign] satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
satellite.connection_made(Mock())
# silence # silence
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
@ -663,10 +672,18 @@ async def test_tts_wrong_wav_format(
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2)) satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity) # silence (assumes relaxed VAD sensitivity)
satellite.on_chunk(bytes(_ONE_SECOND * 4)) satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to exhaust the audio stream # Wait for mock pipeline to exhaust the audio stream
async with asyncio.timeout(1): async with asyncio.timeout(3):
await done.wait() await done.wait()
@ -679,6 +696,7 @@ async def test_empty_tts_output(
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id) satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
satellite.addr = ("192.168.1.1", 12345)
assert isinstance(satellite, VoipAssistSatellite) assert isinstance(satellite, VoipAssistSatellite)
async def async_pipeline_from_audio_stream(*args, **kwargs): async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -728,7 +746,7 @@ async def test_empty_tts_output(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts", "homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts, ) as mock_send_tts,
): ):
satellite.transport = Mock() satellite.connection_made(Mock())
# silence # silence
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
@ -737,10 +755,18 @@ async def test_empty_tts_output(
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2)) satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
# silence (assumes relaxed VAD sensitivity) # silence (assumes relaxed VAD sensitivity)
satellite.on_chunk(bytes(_ONE_SECOND * 4)) satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
# Wait for mock pipeline to finish # Wait for mock pipeline to finish
async with asyncio.timeout(1): async with asyncio.timeout(2):
await satellite._tts_done.wait() await satellite._tts_done.wait()
mock_send_tts.assert_not_called() mock_send_tts.assert_not_called()
@ -785,7 +811,7 @@ async def test_pipeline_error(
), ),
): ):
satellite._tones = Tones.ERROR satellite._tones = Tones.ERROR
satellite.transport = Mock() satellite.connection_made(Mock())
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign] satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
@ -845,16 +871,20 @@ async def test_announce(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts", "homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts, ) as mock_send_tts,
): ):
satellite.transport = Mock()
announce_task = hass.async_create_background_task( announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce" satellite.async_announce(announcement), "voip_announce"
) )
await asyncio.sleep(0) await asyncio.sleep(0)
satellite.connection_made(Mock())
mock_protocol.outgoing_call.assert_called_once() mock_protocol.outgoing_call.assert_called_once()
# Trigger announcement # Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1): await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(2):
await announce_task await announce_task
mock_send_tts.assert_called_once_with( mock_send_tts.assert_called_once_with(
@ -897,11 +927,11 @@ async def test_voip_id_is_ip_address(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts", "homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
) as mock_send_tts, ) as mock_send_tts,
): ):
satellite.transport = Mock()
announce_task = hass.async_create_background_task( announce_task = hass.async_create_background_task(
satellite.async_announce(announcement), "voip_announce" satellite.async_announce(announcement), "voip_announce"
) )
await asyncio.sleep(0) await asyncio.sleep(0)
satellite.connection_made(Mock())
mock_protocol.outgoing_call.assert_called_once() mock_protocol.outgoing_call.assert_called_once()
assert ( assert (
mock_protocol.outgoing_call.call_args.kwargs["destination"].host mock_protocol.outgoing_call.call_args.kwargs["destination"].host
@ -910,7 +940,11 @@ async def test_voip_id_is_ip_address(
# Trigger announcement # Trigger announcement
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1): await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(2):
await announce_task await announce_task
mock_send_tts.assert_called_once_with( mock_send_tts.assert_called_once_with(
@ -955,7 +989,7 @@ async def test_announce_timeout(
0.01, 0.01,
), ),
): ):
satellite.transport = Mock() satellite.connection_made(Mock())
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
await satellite.async_announce(announcement) await satellite.async_announce(announcement)
@ -1042,7 +1076,7 @@ async def test_start_conversation(
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
), ),
): ):
satellite.transport = Mock() satellite.connection_made(Mock())
conversation_task = hass.async_create_background_task( conversation_task = hass.async_create_background_task(
satellite.async_start_conversation(announcement), "voip_start_conversation" satellite.async_start_conversation(announcement), "voip_start_conversation"
) )
@ -1051,16 +1085,20 @@ async def test_start_conversation(
# Trigger announcement and wait for it to finish # Trigger announcement and wait for it to finish
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1): await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
await asyncio.sleep(0.2)
satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(2):
await tts_sent.wait() await tts_sent.wait()
tts_sent.clear()
# Trigger pipeline # Trigger pipeline
satellite.on_chunk(bytes(_ONE_SECOND)) satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1): await asyncio.sleep(0.2)
# Wait for TTS satellite.on_chunk(bytes(_ONE_SECOND))
await tts_sent.wait() await asyncio.sleep(3)
async with asyncio.timeout(3):
# Wait for Conversation end
await conversation_task await conversation_task
@ -1073,21 +1111,8 @@ async def test_start_conversation_user_doesnt_pick_up(
"""Test start conversation when the user doesn't pick up.""" """Test start conversation when the user doesn't pick up."""
assert await async_setup_component(hass, "voip", {}) assert await async_setup_component(hass, "voip", {})
pipeline = assist_pipeline.Pipeline(
conversation_engine="test engine",
conversation_language="en",
language="en",
name="test pipeline",
stt_engine="test stt",
stt_language="en",
tts_engine="test tts",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id) satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
satellite.addr = ("192.168.1.1", 12345)
assert isinstance(satellite, VoipAssistSatellite) assert isinstance(satellite, VoipAssistSatellite)
assert ( assert (
satellite.supported_features satellite.supported_features
@ -1098,62 +1123,22 @@ async def test_start_conversation_user_doesnt_pick_up(
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
mock_protocol.outgoing_call = Mock() mock_protocol.outgoing_call = Mock()
pipeline_started = asyncio.Event() announcement = assist_satellite.AssistSatelliteAnnouncement(
message="test announcement",
async def async_pipeline_from_audio_stream( media_id=_MEDIA_ID,
hass: HomeAssistant, tts_token="test-token",
context: Context, original_media_id=_MEDIA_ID,
*args, media_id_source="tts",
conversation_extra_system_prompt: str | None = None, )
**kwargs,
):
# System prompt should be not be set due to timeout (user not picking up)
assert conversation_extra_system_prompt is None
pipeline_started.set()
# Very short timeout which will trigger because we don't send any audio in
with ( with (
patch( patch(
"homeassistant.components.assist_satellite.entity.async_get_pipeline", "homeassistant.components.voip.assist_satellite._ANNOUNCEMENT_RING_TIMEOUT",
return_value=pipeline, 0.1,
),
patch(
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite.async_start_conversation",
side_effect=TimeoutError,
),
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="test tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
), ),
): ):
satellite.transport = Mock() satellite.connection_made(Mock())
# Error should clear system prompt
with pytest.raises(TimeoutError): with pytest.raises(TimeoutError):
await hass.services.async_call( await satellite.async_start_conversation(announcement)
assist_satellite.DOMAIN,
"start_conversation",
{
"entity_id": satellite.entity_id,
"start_message": "test announcement",
"extra_system_prompt": "test prompt",
},
blocking=True,
)
# Trigger a pipeline so we can check if the system prompt was cleared
satellite.on_chunk(bytes(_ONE_SECOND))
async with asyncio.timeout(1):
await pipeline_started.wait()