mirror of
https://github.com/home-assistant/core.git
synced 2025-07-15 17:27:10 +00:00
Improve Voip pipeline stability (#137620)
* Improve Voip pipeline stability It appears the pipeline is being unexpectedly cancelled in some instances. In order to mitigate this issue hang ups will be detected using a separate task rather than relying on timeouts in the STT read method. Also reading STT events will be retried once if it is cancelled. The pipeline will also catch and log any CancelledErrors to help with further debugging. * Update Voip tests * Remove unnecessary changes Remove unnecessary logging and cancelled error handling in wyoming STT. * Remove comment about clearing system prompt The test no longer checks for clearing the system prompt. Since that logic exists completely in the assist_satellite component I think it is reasonable to only test that logic in the unit tests for that component. * Re-raise cancellation Re-raise CancelledError if the current task is cancelling in the check hangup task Co-authored-by: J. Nick Koston <nick@koston.org> * Re-raise CancelledError in pipeline as well * Fix formatting issue * Remove unnecessary logging * Add MockResultStream import to tests This was presumably missed while merging * Cancel check hangup task on disconnect * Add myself as codeowner for VoIP * Update CODEOWNERS --------- Co-authored-by: J. Nick Koston <nick@koston.org> Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
This commit is contained in:
parent
f3b23afc92
commit
14f967cdd0
4
CODEOWNERS
generated
4
CODEOWNERS
generated
@ -1678,8 +1678,8 @@ build.json @home-assistant/supervisor
|
|||||||
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
/tests/components/vlc_telnet/ @rodripf @MartinHjelmare
|
||||||
/homeassistant/components/vodafone_station/ @paoloantinori @chemelli74
|
/homeassistant/components/vodafone_station/ @paoloantinori @chemelli74
|
||||||
/tests/components/vodafone_station/ @paoloantinori @chemelli74
|
/tests/components/vodafone_station/ @paoloantinori @chemelli74
|
||||||
/homeassistant/components/voip/ @balloob @synesthesiam
|
/homeassistant/components/voip/ @balloob @synesthesiam @jaminh
|
||||||
/tests/components/voip/ @balloob @synesthesiam
|
/tests/components/voip/ @balloob @synesthesiam @jaminh
|
||||||
/homeassistant/components/volumio/ @OnFreund
|
/homeassistant/components/volumio/ @OnFreund
|
||||||
/tests/components/volumio/ @OnFreund
|
/tests/components/volumio/ @OnFreund
|
||||||
/homeassistant/components/volvooncall/ @molobrakos
|
/homeassistant/components/volvooncall/ @molobrakos
|
||||||
|
@ -51,9 +51,9 @@ if TYPE_CHECKING:
|
|||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
_PIPELINE_TIMEOUT_SEC: Final = 30
|
_PIPELINE_TIMEOUT_SEC: Final = 30
|
||||||
|
_HANGUP_SEC: Final = 0.5
|
||||||
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
|
_ANNOUNCEMENT_BEFORE_DELAY: Final = 0.5
|
||||||
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
|
_ANNOUNCEMENT_AFTER_DELAY: Final = 1.0
|
||||||
_ANNOUNCEMENT_HANGUP_SEC: Final = 0.5
|
|
||||||
_ANNOUNCEMENT_RING_TIMEOUT: Final = 30
|
_ANNOUNCEMENT_RING_TIMEOUT: Final = 30
|
||||||
|
|
||||||
|
|
||||||
@ -132,9 +132,10 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
self._processing_tone_done = asyncio.Event()
|
self._processing_tone_done = asyncio.Event()
|
||||||
|
|
||||||
self._announcement: AssistSatelliteAnnouncement | None = None
|
self._announcement: AssistSatelliteAnnouncement | None = None
|
||||||
self._announcement_future: asyncio.Future[Any] = asyncio.Future()
|
|
||||||
self._announcment_start_time: float = 0.0
|
self._announcment_start_time: float = 0.0
|
||||||
self._check_announcement_ended_task: asyncio.Task | None = None
|
self._check_announcement_pickup_task: asyncio.Task | None = None
|
||||||
|
self._check_hangup_task: asyncio.Task | None = None
|
||||||
|
self._call_end_future: asyncio.Future[Any] = asyncio.Future()
|
||||||
self._last_chunk_time: float | None = None
|
self._last_chunk_time: float | None = None
|
||||||
self._rtp_port: int | None = None
|
self._rtp_port: int | None = None
|
||||||
self._run_pipeline_after_announce: bool = False
|
self._run_pipeline_after_announce: bool = False
|
||||||
@ -233,7 +234,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
translation_key="non_tts_announcement",
|
translation_key="non_tts_announcement",
|
||||||
)
|
)
|
||||||
|
|
||||||
self._announcement_future = asyncio.Future()
|
self._call_end_future = asyncio.Future()
|
||||||
self._run_pipeline_after_announce = run_pipeline_after
|
self._run_pipeline_after_announce = run_pipeline_after
|
||||||
|
|
||||||
if self._rtp_port is None:
|
if self._rtp_port is None:
|
||||||
@ -274,53 +275,77 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
rtp_port=self._rtp_port,
|
rtp_port=self._rtp_port,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if caller hung up or didn't pick up
|
# Check if caller didn't pick up
|
||||||
self._check_announcement_ended_task = (
|
self._check_announcement_pickup_task = (
|
||||||
self.config_entry.async_create_background_task(
|
self.config_entry.async_create_background_task(
|
||||||
self.hass,
|
self.hass,
|
||||||
self._check_announcement_ended(),
|
self._check_announcement_pickup(),
|
||||||
"voip_announcement_ended",
|
"voip_announcement_pickup",
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self._announcement_future
|
await self._call_end_future
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
# Stop ringing
|
# Stop ringing
|
||||||
|
_LOGGER.debug("Caller did not pick up in time")
|
||||||
sip_protocol.cancel_call(call_info)
|
sip_protocol.cancel_call(call_info)
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _check_announcement_ended(self) -> None:
|
async def _check_announcement_pickup(self) -> None:
|
||||||
"""Continuously checks if an audio chunk was received within a time limit.
|
"""Continuously checks if an audio chunk was received within a time limit.
|
||||||
|
|
||||||
If not, the caller is presumed to have hung up and the announcement is ended.
|
If not, the caller is presumed to have not picked up the phone and the announcement is ended.
|
||||||
"""
|
"""
|
||||||
while self._announcement is not None:
|
while True:
|
||||||
current_time = time.monotonic()
|
current_time = time.monotonic()
|
||||||
if (self._last_chunk_time is None) and (
|
if (self._last_chunk_time is None) and (
|
||||||
(current_time - self._announcment_start_time)
|
(current_time - self._announcment_start_time)
|
||||||
> _ANNOUNCEMENT_RING_TIMEOUT
|
> _ANNOUNCEMENT_RING_TIMEOUT
|
||||||
):
|
):
|
||||||
# Ring timeout
|
# Ring timeout
|
||||||
|
_LOGGER.debug("Ring timeout")
|
||||||
self._announcement = None
|
self._announcement = None
|
||||||
self._check_announcement_ended_task = None
|
self._check_announcement_pickup_task = None
|
||||||
self._announcement_future.set_exception(
|
self._call_end_future.set_exception(
|
||||||
TimeoutError("User did not pick up in time")
|
TimeoutError("User did not pick up in time")
|
||||||
)
|
)
|
||||||
_LOGGER.debug("Timed out waiting for the user to pick up the phone")
|
_LOGGER.debug("Timed out waiting for the user to pick up the phone")
|
||||||
break
|
break
|
||||||
|
if self._last_chunk_time is not None:
|
||||||
if (self._last_chunk_time is not None) and (
|
_LOGGER.debug("Picked up the phone")
|
||||||
(current_time - self._last_chunk_time) > _ANNOUNCEMENT_HANGUP_SEC
|
self._check_announcement_pickup_task = None
|
||||||
):
|
|
||||||
# Caller hung up
|
|
||||||
self._announcement = None
|
|
||||||
self._announcement_future.set_result(None)
|
|
||||||
self._check_announcement_ended_task = None
|
|
||||||
_LOGGER.debug("Announcement ended")
|
|
||||||
break
|
break
|
||||||
|
|
||||||
await asyncio.sleep(_ANNOUNCEMENT_HANGUP_SEC / 2)
|
await asyncio.sleep(_HANGUP_SEC / 2)
|
||||||
|
|
||||||
|
async def _check_hangup(self) -> None:
|
||||||
|
"""Continuously checks if an audio chunk was received within a time limit.
|
||||||
|
|
||||||
|
If not, the caller is presumed to have hung up and the call is ended.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
current_time = time.monotonic()
|
||||||
|
if (self._last_chunk_time is not None) and (
|
||||||
|
(current_time - self._last_chunk_time) > _HANGUP_SEC
|
||||||
|
):
|
||||||
|
# Caller hung up
|
||||||
|
_LOGGER.debug("Hang up")
|
||||||
|
self._announcement = None
|
||||||
|
if self._run_pipeline_task is not None:
|
||||||
|
_LOGGER.debug("Cancelling running pipeline")
|
||||||
|
self._run_pipeline_task.cancel()
|
||||||
|
self._call_end_future.set_result(None)
|
||||||
|
self.disconnect()
|
||||||
|
break
|
||||||
|
|
||||||
|
await asyncio.sleep(_HANGUP_SEC / 2)
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
# Don't swallow cancellation
|
||||||
|
if (current_task := asyncio.current_task()) and current_task.cancelling():
|
||||||
|
raise
|
||||||
|
_LOGGER.debug("Check hangup cancelled")
|
||||||
|
|
||||||
async def async_start_conversation(
|
async def async_start_conversation(
|
||||||
self, start_announcement: AssistSatelliteAnnouncement
|
self, start_announcement: AssistSatelliteAnnouncement
|
||||||
@ -332,6 +357,24 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
# VoIP
|
# VoIP
|
||||||
# -------------------------------------------------------------------------
|
# -------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def disconnect(self):
|
||||||
|
"""Server disconnected."""
|
||||||
|
super().disconnect()
|
||||||
|
if self._check_hangup_task is not None:
|
||||||
|
self._check_hangup_task.cancel()
|
||||||
|
self._check_hangup_task = None
|
||||||
|
|
||||||
|
def connection_made(self, transport):
|
||||||
|
"""Server is ready."""
|
||||||
|
super().connection_made(transport)
|
||||||
|
self._last_chunk_time = time.monotonic()
|
||||||
|
# Check if caller hung up
|
||||||
|
self._check_hangup_task = self.config_entry.async_create_background_task(
|
||||||
|
self.hass,
|
||||||
|
self._check_hangup(),
|
||||||
|
"voip_hangup",
|
||||||
|
)
|
||||||
|
|
||||||
def on_chunk(self, audio_bytes: bytes) -> None:
|
def on_chunk(self, audio_bytes: bytes) -> None:
|
||||||
"""Handle raw audio chunk."""
|
"""Handle raw audio chunk."""
|
||||||
self._last_chunk_time = time.monotonic()
|
self._last_chunk_time = time.monotonic()
|
||||||
@ -368,13 +411,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
self.voip_device.set_is_active(True)
|
self.voip_device.set_is_active(True)
|
||||||
|
|
||||||
async def stt_stream():
|
async def stt_stream():
|
||||||
|
retry: bool = True
|
||||||
while True:
|
while True:
|
||||||
async with asyncio.timeout(self._audio_chunk_timeout):
|
try:
|
||||||
chunk = await self._audio_queue.get()
|
async with asyncio.timeout(self._audio_chunk_timeout):
|
||||||
if not chunk:
|
chunk = await self._audio_queue.get()
|
||||||
break
|
if not chunk:
|
||||||
|
_LOGGER.debug("STT stream got None")
|
||||||
|
break
|
||||||
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
except TimeoutError:
|
||||||
|
_LOGGER.debug("STT Stream timed out")
|
||||||
|
if not retry:
|
||||||
|
_LOGGER.debug("No more retries, ending STT stream")
|
||||||
|
break
|
||||||
|
retry = False
|
||||||
|
|
||||||
# Play listening tone at the start of each cycle
|
# Play listening tone at the start of each cycle
|
||||||
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
await self._play_tone(Tones.LISTENING, silence_before=0.2)
|
||||||
@ -385,6 +437,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
)
|
)
|
||||||
|
|
||||||
if self._pipeline_had_error:
|
if self._pipeline_had_error:
|
||||||
|
_LOGGER.debug("Pipeline error")
|
||||||
self._pipeline_had_error = False
|
self._pipeline_had_error = False
|
||||||
await self._play_tone(Tones.ERROR)
|
await self._play_tone(Tones.ERROR)
|
||||||
else:
|
else:
|
||||||
@ -394,7 +447,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
# length of the TTS audio.
|
# length of the TTS audio.
|
||||||
await self._tts_done.wait()
|
await self._tts_done.wait()
|
||||||
except TimeoutError:
|
except TimeoutError:
|
||||||
|
# This shouldn't happen anymore, we are detecting hang ups with a separate task
|
||||||
|
_LOGGER.exception("Timeout error")
|
||||||
self.disconnect() # caller hung up
|
self.disconnect() # caller hung up
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
_LOGGER.debug("Pipeline cancelled")
|
||||||
|
# Don't swallow cancellation
|
||||||
|
if (current_task := asyncio.current_task()) and current_task.cancelling():
|
||||||
|
raise
|
||||||
finally:
|
finally:
|
||||||
# Stop audio stream
|
# Stop audio stream
|
||||||
await self._audio_queue.put(None)
|
await self._audio_queue.put(None)
|
||||||
@ -433,8 +493,8 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
|
|
||||||
if self._run_pipeline_after_announce:
|
if self._run_pipeline_after_announce:
|
||||||
# Clear announcement to allow pipeline to run
|
# Clear announcement to allow pipeline to run
|
||||||
|
_LOGGER.debug("Clearing announcement")
|
||||||
self._announcement = None
|
self._announcement = None
|
||||||
self._announcement_future.set_result(None)
|
|
||||||
|
|
||||||
def _clear_audio_queue(self) -> None:
|
def _clear_audio_queue(self) -> None:
|
||||||
"""Ensure audio queue is empty."""
|
"""Ensure audio queue is empty."""
|
||||||
@ -463,6 +523,7 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
|
|||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Empty TTS response
|
# Empty TTS response
|
||||||
|
_LOGGER.debug("Empty TTS response")
|
||||||
self._tts_done.set()
|
self._tts_done.set()
|
||||||
elif event.type == PipelineEventType.ERROR:
|
elif event.type == PipelineEventType.ERROR:
|
||||||
# Play error tone instead of wait for TTS when pipeline is finished.
|
# Play error tone instead of wait for TTS when pipeline is finished.
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
{
|
{
|
||||||
"domain": "voip",
|
"domain": "voip",
|
||||||
"name": "Voice over IP",
|
"name": "Voice over IP",
|
||||||
"codeowners": ["@balloob", "@synesthesiam"],
|
"codeowners": ["@balloob", "@synesthesiam", "@jaminh"],
|
||||||
"config_flow": true,
|
"config_flow": true,
|
||||||
"dependencies": ["assist_pipeline", "assist_satellite", "intent", "network"],
|
"dependencies": ["assist_pipeline", "assist_satellite", "intent", "network"],
|
||||||
"documentation": "https://www.home-assistant.io/integrations/voip",
|
"documentation": "https://www.home-assistant.io/integrations/voip",
|
||||||
|
@ -335,9 +335,8 @@ async def test_pipeline(
|
|||||||
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
patch.object(satellite, "tts_response_finished", tts_response_finished),
|
||||||
):
|
):
|
||||||
satellite._tones = Tones(0)
|
satellite._tones = Tones(0)
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
|
|
||||||
satellite.connection_made(satellite.transport)
|
|
||||||
assert satellite.state == AssistSatelliteState.IDLE
|
assert satellite.state == AssistSatelliteState.IDLE
|
||||||
|
|
||||||
# Ensure audio queue is cleared before pipeline starts
|
# Ensure audio queue is cleared before pipeline starts
|
||||||
@ -473,7 +472,7 @@ async def test_tts_timeout(
|
|||||||
for tone in Tones:
|
for tone in Tones:
|
||||||
satellite._tone_bytes[tone] = tone_bytes
|
satellite._tone_bytes[tone] = tone_bytes
|
||||||
|
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
satellite.send_audio = Mock()
|
satellite.send_audio = Mock()
|
||||||
|
|
||||||
original_send_tts = satellite._send_tts
|
original_send_tts = satellite._send_tts
|
||||||
@ -511,6 +510,7 @@ async def test_tts_wrong_extension(
|
|||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
satellite.addr = ("192.168.1.1", 12345)
|
||||||
assert isinstance(satellite, VoipAssistSatellite)
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
@ -559,8 +559,6 @@ async def test_tts_wrong_extension(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
|
||||||
|
|
||||||
original_send_tts = satellite._send_tts
|
original_send_tts = satellite._send_tts
|
||||||
|
|
||||||
async def send_tts(*args, **kwargs):
|
async def send_tts(*args, **kwargs):
|
||||||
@ -572,6 +570,8 @@ async def test_tts_wrong_extension(
|
|||||||
|
|
||||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
satellite.connection_made(Mock())
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
@ -579,10 +579,18 @@ async def test_tts_wrong_extension(
|
|||||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(3):
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
@ -595,6 +603,7 @@ async def test_tts_wrong_wav_format(
|
|||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
satellite.addr = ("192.168.1.1", 12345)
|
||||||
assert isinstance(satellite, VoipAssistSatellite)
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
done = asyncio.Event()
|
done = asyncio.Event()
|
||||||
@ -643,8 +652,6 @@ async def test_tts_wrong_wav_format(
|
|||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
|
||||||
|
|
||||||
original_send_tts = satellite._send_tts
|
original_send_tts = satellite._send_tts
|
||||||
|
|
||||||
async def send_tts(*args, **kwargs):
|
async def send_tts(*args, **kwargs):
|
||||||
@ -656,6 +663,8 @@ async def test_tts_wrong_wav_format(
|
|||||||
|
|
||||||
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
satellite._send_tts = AsyncMock(side_effect=send_tts) # type: ignore[method-assign]
|
||||||
|
|
||||||
|
satellite.connection_made(Mock())
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
@ -663,10 +672,18 @@ async def test_tts_wrong_wav_format(
|
|||||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to exhaust the audio stream
|
# Wait for mock pipeline to exhaust the audio stream
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(3):
|
||||||
await done.wait()
|
await done.wait()
|
||||||
|
|
||||||
|
|
||||||
@ -679,6 +696,7 @@ async def test_empty_tts_output(
|
|||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
satellite.addr = ("192.168.1.1", 12345)
|
||||||
assert isinstance(satellite, VoipAssistSatellite)
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
@ -728,7 +746,7 @@ async def test_empty_tts_output(
|
|||||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||||
) as mock_send_tts,
|
) as mock_send_tts,
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
|
|
||||||
# silence
|
# silence
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
@ -737,10 +755,18 @@ async def test_empty_tts_output(
|
|||||||
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
satellite.on_chunk(bytes([255] * _ONE_SECOND * 2))
|
||||||
|
|
||||||
# silence (assumes relaxed VAD sensitivity)
|
# silence (assumes relaxed VAD sensitivity)
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND * 4))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
|
||||||
# Wait for mock pipeline to finish
|
# Wait for mock pipeline to finish
|
||||||
async with asyncio.timeout(1):
|
async with asyncio.timeout(2):
|
||||||
await satellite._tts_done.wait()
|
await satellite._tts_done.wait()
|
||||||
|
|
||||||
mock_send_tts.assert_not_called()
|
mock_send_tts.assert_not_called()
|
||||||
@ -785,7 +811,7 @@ async def test_pipeline_error(
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
satellite._tones = Tones.ERROR
|
satellite._tones = Tones.ERROR
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
satellite._async_send_audio = AsyncMock(side_effect=async_send_audio) # type: ignore[method-assign]
|
||||||
|
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
@ -845,16 +871,20 @@ async def test_announce(
|
|||||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||||
) as mock_send_tts,
|
) as mock_send_tts,
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
|
||||||
announce_task = hass.async_create_background_task(
|
announce_task = hass.async_create_background_task(
|
||||||
satellite.async_announce(announcement), "voip_announce"
|
satellite.async_announce(announcement), "voip_announce"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
satellite.connection_made(Mock())
|
||||||
mock_protocol.outgoing_call.assert_called_once()
|
mock_protocol.outgoing_call.assert_called_once()
|
||||||
|
|
||||||
# Trigger announcement
|
# Trigger announcement
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
async with asyncio.timeout(1):
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(2):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(
|
mock_send_tts.assert_called_once_with(
|
||||||
@ -897,11 +927,11 @@ async def test_voip_id_is_ip_address(
|
|||||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite._send_tts",
|
||||||
) as mock_send_tts,
|
) as mock_send_tts,
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
|
||||||
announce_task = hass.async_create_background_task(
|
announce_task = hass.async_create_background_task(
|
||||||
satellite.async_announce(announcement), "voip_announce"
|
satellite.async_announce(announcement), "voip_announce"
|
||||||
)
|
)
|
||||||
await asyncio.sleep(0)
|
await asyncio.sleep(0)
|
||||||
|
satellite.connection_made(Mock())
|
||||||
mock_protocol.outgoing_call.assert_called_once()
|
mock_protocol.outgoing_call.assert_called_once()
|
||||||
assert (
|
assert (
|
||||||
mock_protocol.outgoing_call.call_args.kwargs["destination"].host
|
mock_protocol.outgoing_call.call_args.kwargs["destination"].host
|
||||||
@ -910,7 +940,11 @@ async def test_voip_id_is_ip_address(
|
|||||||
|
|
||||||
# Trigger announcement
|
# Trigger announcement
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
async with asyncio.timeout(1):
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(2):
|
||||||
await announce_task
|
await announce_task
|
||||||
|
|
||||||
mock_send_tts.assert_called_once_with(
|
mock_send_tts.assert_called_once_with(
|
||||||
@ -955,7 +989,7 @@ async def test_announce_timeout(
|
|||||||
0.01,
|
0.01,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
await satellite.async_announce(announcement)
|
await satellite.async_announce(announcement)
|
||||||
|
|
||||||
@ -1042,7 +1076,7 @@ async def test_start_conversation(
|
|||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
conversation_task = hass.async_create_background_task(
|
conversation_task = hass.async_create_background_task(
|
||||||
satellite.async_start_conversation(announcement), "voip_start_conversation"
|
satellite.async_start_conversation(announcement), "voip_start_conversation"
|
||||||
)
|
)
|
||||||
@ -1051,16 +1085,20 @@ async def test_start_conversation(
|
|||||||
|
|
||||||
# Trigger announcement and wait for it to finish
|
# Trigger announcement and wait for it to finish
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
async with asyncio.timeout(1):
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
await asyncio.sleep(0.2)
|
||||||
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
|
async with asyncio.timeout(2):
|
||||||
await tts_sent.wait()
|
await tts_sent.wait()
|
||||||
|
|
||||||
tts_sent.clear()
|
|
||||||
|
|
||||||
# Trigger pipeline
|
# Trigger pipeline
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
async with asyncio.timeout(1):
|
await asyncio.sleep(0.2)
|
||||||
# Wait for TTS
|
satellite.on_chunk(bytes(_ONE_SECOND))
|
||||||
await tts_sent.wait()
|
await asyncio.sleep(3)
|
||||||
|
async with asyncio.timeout(3):
|
||||||
|
# Wait for Conversation end
|
||||||
await conversation_task
|
await conversation_task
|
||||||
|
|
||||||
|
|
||||||
@ -1073,21 +1111,8 @@ async def test_start_conversation_user_doesnt_pick_up(
|
|||||||
"""Test start conversation when the user doesn't pick up."""
|
"""Test start conversation when the user doesn't pick up."""
|
||||||
assert await async_setup_component(hass, "voip", {})
|
assert await async_setup_component(hass, "voip", {})
|
||||||
|
|
||||||
pipeline = assist_pipeline.Pipeline(
|
|
||||||
conversation_engine="test engine",
|
|
||||||
conversation_language="en",
|
|
||||||
language="en",
|
|
||||||
name="test pipeline",
|
|
||||||
stt_engine="test stt",
|
|
||||||
stt_language="en",
|
|
||||||
tts_engine="test tts",
|
|
||||||
tts_language="en",
|
|
||||||
tts_voice=None,
|
|
||||||
wake_word_entity=None,
|
|
||||||
wake_word_id=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
satellite = async_get_satellite_entity(hass, voip.DOMAIN, voip_device.voip_id)
|
||||||
|
satellite.addr = ("192.168.1.1", 12345)
|
||||||
assert isinstance(satellite, VoipAssistSatellite)
|
assert isinstance(satellite, VoipAssistSatellite)
|
||||||
assert (
|
assert (
|
||||||
satellite.supported_features
|
satellite.supported_features
|
||||||
@ -1098,62 +1123,22 @@ async def test_start_conversation_user_doesnt_pick_up(
|
|||||||
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
mock_protocol: AsyncMock = hass.data[DOMAIN].protocol
|
||||||
mock_protocol.outgoing_call = Mock()
|
mock_protocol.outgoing_call = Mock()
|
||||||
|
|
||||||
pipeline_started = asyncio.Event()
|
announcement = assist_satellite.AssistSatelliteAnnouncement(
|
||||||
|
message="test announcement",
|
||||||
async def async_pipeline_from_audio_stream(
|
media_id=_MEDIA_ID,
|
||||||
hass: HomeAssistant,
|
tts_token="test-token",
|
||||||
context: Context,
|
original_media_id=_MEDIA_ID,
|
||||||
*args,
|
media_id_source="tts",
|
||||||
conversation_extra_system_prompt: str | None = None,
|
)
|
||||||
**kwargs,
|
|
||||||
):
|
|
||||||
# System prompt should be not be set due to timeout (user not picking up)
|
|
||||||
assert conversation_extra_system_prompt is None
|
|
||||||
|
|
||||||
pipeline_started.set()
|
|
||||||
|
|
||||||
|
# Very short timeout which will trigger because we don't send any audio in
|
||||||
with (
|
with (
|
||||||
patch(
|
patch(
|
||||||
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
|
"homeassistant.components.voip.assist_satellite._ANNOUNCEMENT_RING_TIMEOUT",
|
||||||
return_value=pipeline,
|
0.1,
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.voip.assist_satellite.VoipAssistSatellite.async_start_conversation",
|
|
||||||
side_effect=TimeoutError,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
|
|
||||||
new=async_pipeline_from_audio_stream,
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.generate_media_source_id",
|
|
||||||
return_value="media-source://bla",
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.async_resolve_engine",
|
|
||||||
return_value="test tts",
|
|
||||||
),
|
|
||||||
patch(
|
|
||||||
"homeassistant.components.tts.async_create_stream",
|
|
||||||
return_value=MockResultStream(hass, "wav", b""),
|
|
||||||
),
|
),
|
||||||
):
|
):
|
||||||
satellite.transport = Mock()
|
satellite.connection_made(Mock())
|
||||||
|
|
||||||
# Error should clear system prompt
|
|
||||||
with pytest.raises(TimeoutError):
|
with pytest.raises(TimeoutError):
|
||||||
await hass.services.async_call(
|
await satellite.async_start_conversation(announcement)
|
||||||
assist_satellite.DOMAIN,
|
|
||||||
"start_conversation",
|
|
||||||
{
|
|
||||||
"entity_id": satellite.entity_id,
|
|
||||||
"start_message": "test announcement",
|
|
||||||
"extra_system_prompt": "test prompt",
|
|
||||||
},
|
|
||||||
blocking=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Trigger a pipeline so we can check if the system prompt was cleared
|
|
||||||
satellite.on_chunk(bytes(_ONE_SECOND))
|
|
||||||
async with asyncio.timeout(1):
|
|
||||||
await pipeline_started.wait()
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user