diff --git a/homeassistant/components/voip/assist_satellite.py b/homeassistant/components/voip/assist_satellite.py index 2c0a3b9641a..6c63710a5b1 100644 --- a/homeassistant/components/voip/assist_satellite.py +++ b/homeassistant/components/voip/assist_satellite.py @@ -408,10 +408,18 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol """Play an announcement once.""" _LOGGER.debug("Playing announcement") - try: - await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY) - await self._send_tts(announcement.original_media_id, wait_for_tone=False) + if announcement.tts_token is None: + _LOGGER.error("Only TTS announcements are supported") + return + await asyncio.sleep(_ANNOUNCEMENT_BEFORE_DELAY) + stream = tts.async_get_stream(self.hass, announcement.tts_token) + if stream is None: + _LOGGER.error("TTS stream no longer available") + return + + try: + await self._send_tts(stream, wait_for_tone=False) if not self._run_pipeline_after_announce: # Delay before looping announcement await asyncio.sleep(_ANNOUNCEMENT_AFTER_DELAY) @@ -442,11 +450,14 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol ) elif event.type == PipelineEventType.TTS_END: # Send TTS audio to caller over RTP - if event.data and (tts_output := event.data["tts_output"]): - media_id = tts_output["media_id"] + if ( + event.data + and (tts_output := event.data["tts_output"]) + and (stream := tts.async_get_stream(self.hass, tts_output["token"])) + ): self.config_entry.async_create_background_task( self.hass, - self._send_tts(media_id), + self._send_tts(tts_stream=stream), "voip_pipeline_tts", ) else: @@ -457,19 +468,22 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol self._pipeline_had_error = True _LOGGER.warning(event) - async def _send_tts(self, media_id: str, wait_for_tone: bool = True) -> None: + async def _send_tts( + self, + tts_stream: tts.ResultStream, + wait_for_tone: bool = True, + ) -> None: """Send TTS audio 
to caller via RTP.""" try: if self.transport is None: return # not connected - extension, data = await tts.async_get_media_source_audio( - self.hass, - media_id, - ) + data = b"".join([chunk async for chunk in tts_stream.async_stream_result()]) - if extension != "wav": - raise ValueError(f"Only WAV audio can be streamed, got {extension}") + if tts_stream.extension != "wav": + raise ValueError( + f"Only TTS WAV audio can be streamed, got {tts_stream.extension}" + ) if wait_for_tone and ((self._tones & Tones.PROCESSING) == Tones.PROCESSING): # Don't overlap TTS and processing beep diff --git a/tests/components/voip/test_voip.py b/tests/components/voip/test_voip.py index 7ac76227a1b..345f0399645 100644 --- a/tests/components/voip/test_voip.py +++ b/tests/components/voip/test_voip.py @@ -38,12 +38,12 @@ def mock_tts_cache_dir_autouse(mock_tts_cache_dir: Path) -> None: """Mock the TTS cache dir with empty dir.""" -def _empty_wav() -> bytes: +def _empty_wav(framerate=16000) -> bytes: """Return bytes of an empty WAV file.""" with io.BytesIO() as wav_io: wav_file: wave.Wave_write = wave.open(wav_io, "wb") with wav_file: - wav_file.setframerate(16000) + wav_file.setframerate(framerate) wav_file.setsampwidth(2) wav_file.setnchannels(1) @@ -307,10 +307,11 @@ async def test_pipeline( assert satellite.state == AssistSatelliteState.RESPONDING # Proceed with media output + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) event_callback( assist_pipeline.PipelineEvent( type=assist_pipeline.PipelineEventType.TTS_END, - data={"tts_output": {"media_id": _MEDIA_ID}}, + data={"tts_output": {"token": mock_tts_result_stream.token}}, ) ) @@ -326,22 +327,11 @@ async def test_pipeline( original_tts_response_finished() done.set() - async def async_get_media_source_audio( - hass: HomeAssistant, - media_source_id: str, - ) -> tuple[str, bytes]: - assert media_source_id == _MEDIA_ID - return ("wav", _empty_wav()) - with ( patch( 
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ), - patch( - "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio", - new=async_get_media_source_audio, - ), patch.object(satellite, "tts_response_finished", tts_response_finished), ): satellite._tones = Tones(0) @@ -457,10 +447,11 @@ async def test_tts_timeout( ) # Proceed with media output + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) event_callback( assist_pipeline.PipelineEvent( type=assist_pipeline.PipelineEventType.TTS_END, - data={"tts_output": {"media_id": _MEDIA_ID}}, + data={"tts_output": {"token": mock_tts_result_stream.token}}, ) ) @@ -474,22 +465,9 @@ async def test_tts_timeout( # Block here to force a timeout in _send_tts await asyncio.sleep(2) - async def async_get_media_source_audio( - hass: HomeAssistant, - media_source_id: str, - ) -> tuple[str, bytes]: - # Should time out immediately - return ("wav", _empty_wav()) - - with ( - patch( - "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), - patch( - "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio", - new=async_get_media_source_audio, - ), + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, ): satellite._tts_extra_timeout = 0.001 for tone in Tones: @@ -568,29 +546,18 @@ async def test_tts_wrong_extension( ) # Proceed with media output + # Should fail because it's not "wav" + mock_tts_result_stream = MockResultStream(hass, "mp3", b"") event_callback( assist_pipeline.PipelineEvent( type=assist_pipeline.PipelineEventType.TTS_END, - data={"tts_output": {"media_id": _MEDIA_ID}}, + data={"tts_output": {"token": mock_tts_result_stream.token}}, ) ) - async def async_get_media_source_audio( - hass: HomeAssistant, - media_source_id: str, - 
) -> tuple[str, bytes]: - # Should fail because it's not "wav" - return ("mp3", b"") - - with ( - patch( - "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), - patch( - "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio", - new=async_get_media_source_audio, - ), + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, ): satellite.transport = Mock() @@ -663,36 +630,18 @@ async def test_tts_wrong_wav_format( ) # Proceed with media output + # Should fail because it's not 16 kHz + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav(22050)) event_callback( assist_pipeline.PipelineEvent( type=assist_pipeline.PipelineEventType.TTS_END, - data={"tts_output": {"media_id": _MEDIA_ID}}, + data={"tts_output": {"token": mock_tts_result_stream.token}}, ) ) - async def async_get_media_source_audio( - hass: HomeAssistant, - media_source_id: str, - ) -> tuple[str, bytes]: - # Should fail because it's not 16Khz, 16-bit mono - with io.BytesIO() as wav_io: - wav_file: wave.Wave_write = wave.open(wav_io, "wb") - with wav_file: - wav_file.setframerate(22050) - wav_file.setsampwidth(2) - wav_file.setnchannels(2) - - return ("wav", wav_io.getvalue()) - - with ( - patch( - "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", - new=async_pipeline_from_audio_stream, - ), - patch( - "homeassistant.components.voip.assist_satellite.tts.async_get_media_source_audio", - new=async_get_media_source_audio, - ), + with patch( + "homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream", + new=async_pipeline_from_audio_stream, ): satellite.transport = Mock() @@ -878,10 +827,11 @@ async def test_announce( assert err.value.translation_domain == "voip" assert err.value.translation_key == "non_tts_announcement" + mock_tts_result_stream = 
MockResultStream(hass, "wav", _empty_wav()) announcement = assist_satellite.AssistSatelliteAnnouncement( message="test announcement", media_id=_MEDIA_ID, - tts_token="test-token", + tts_token=mock_tts_result_stream.token, original_media_id=_MEDIA_ID, media_id_source="tts", ) @@ -907,7 +857,9 @@ async def test_announce( async with asyncio.timeout(1): await announce_task - mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False) + mock_send_tts.assert_called_once_with( + mock_tts_result_stream, wait_for_tone=False + ) @pytest.mark.usefixtures("socket_enabled") @@ -926,10 +878,11 @@ async def test_voip_id_is_ip_address( & assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE ) + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) announcement = assist_satellite.AssistSatelliteAnnouncement( message="test announcement", media_id=_MEDIA_ID, - tts_token="test-token", + tts_token=mock_tts_result_stream.token, original_media_id=_MEDIA_ID, media_id_source="tts", ) @@ -960,7 +913,9 @@ async def test_voip_id_is_ip_address( async with asyncio.timeout(1): await announce_task - mock_send_tts.assert_called_once_with(_MEDIA_ID, wait_for_tone=False) + mock_send_tts.assert_called_once_with( + mock_tts_result_stream, wait_for_tone=False + ) @pytest.mark.usefixtures("socket_enabled") @@ -979,10 +934,11 @@ async def test_announce_timeout( & assist_satellite.AssistSatelliteEntityFeature.ANNOUNCE ) + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) announcement = assist_satellite.AssistSatelliteAnnouncement( message="test announcement", media_id=_MEDIA_ID, - tts_token="test-token", + tts_token=mock_tts_result_stream.token, original_media_id=_MEDIA_ID, media_id_source="tts", ) @@ -1020,10 +976,11 @@ async def test_start_conversation( & assist_satellite.AssistSatelliteEntityFeature.START_CONVERSATION ) + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) announcement = assist_satellite.AssistSatelliteAnnouncement( 
message="test announcement", media_id=_MEDIA_ID, - tts_token="test-token", + tts_token=mock_tts_result_stream.token, original_media_id=_MEDIA_ID, media_id_source="tts", ) @@ -1061,10 +1018,11 @@ async def test_start_conversation( ) # Proceed with media output + mock_tts_result_stream = MockResultStream(hass, "wav", _empty_wav()) event_callback( assist_pipeline.PipelineEvent( type=assist_pipeline.PipelineEventType.TTS_END, - data={"tts_output": {"media_id": _MEDIA_ID}}, + data={"tts_output": {"token": mock_tts_result_stream.token}}, ) )