From 43daeb26303208c0cf2183d548a705f840fa1fb5 Mon Sep 17 00:00:00 2001 From: Michael Hansen Date: Thu, 7 Dec 2023 19:44:43 -0600 Subject: [PATCH] Set device id and forward errors to Wyoming satellites (#105266) * Set device id and forward errors * Fix tests --- .../components/wyoming/manifest.json | 2 +- homeassistant/components/wyoming/satellite.py | 12 +++++ requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- .../wyoming/snapshots/test_stt.ambr | 2 +- tests/components/wyoming/test_satellite.py | 50 ++++++++++++++++++- 6 files changed, 65 insertions(+), 5 deletions(-) diff --git a/homeassistant/components/wyoming/manifest.json b/homeassistant/components/wyoming/manifest.json index 540aaa9aeac..7174683fd18 100644 --- a/homeassistant/components/wyoming/manifest.json +++ b/homeassistant/components/wyoming/manifest.json @@ -6,6 +6,6 @@ "dependencies": ["assist_pipeline"], "documentation": "https://www.home-assistant.io/integrations/wyoming", "iot_class": "local_push", - "requirements": ["wyoming==1.3.0"], + "requirements": ["wyoming==1.4.0"], "zeroconf": ["_wyoming._tcp.local."] } diff --git a/homeassistant/components/wyoming/satellite.py b/homeassistant/components/wyoming/satellite.py index 1cc3fde2a9c..94f61c17047 100644 --- a/homeassistant/components/wyoming/satellite.py +++ b/homeassistant/components/wyoming/satellite.py @@ -9,6 +9,7 @@ import wave from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.client import AsyncTcpClient +from wyoming.error import Error from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite from wyoming.tts import Synthesize, SynthesizeVoice @@ -239,6 +240,7 @@ class WyomingSatellite: auto_gain_dbfs=self.device.auto_gain, volume_multiplier=self.device.volume_multiplier, ), + device_id=self.device.device_id, ) ) @@ -333,6 +335,16 @@ class WyomingSatellite: if event.data and (tts_output := event.data["tts_output"]): media_id = tts_output["media_id"] self.hass.add_job(self._stream_tts(media_id)) + elif event.type == assist_pipeline.PipelineEventType.ERROR: + # Pipeline error + if event.data: + self.hass.add_job( + self._client.write_event( + Error( + text=event.data["message"], code=event.data["code"] + ).event() + ) + ) async def _connect(self) -> None: """Connect to satellite over TCP.""" diff --git a/requirements_all.txt b/requirements_all.txt index e4fcb0cb396..935f5f78075 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2760,7 +2760,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.3.0 +wyoming==1.4.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index e46bcbbe862..741b40b5ee4 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -2064,7 +2064,7 @@ wled==0.17.0 wolf-smartset==0.1.11 # homeassistant.components.wyoming -wyoming==1.3.0 +wyoming==1.4.0 # homeassistant.components.xbox xbox-webapi==2.0.11 diff --git a/tests/components/wyoming/snapshots/test_stt.ambr b/tests/components/wyoming/snapshots/test_stt.ambr index 784f89b2ab8..b45b7508b28 100644 --- a/tests/components/wyoming/snapshots/test_stt.ambr +++ b/tests/components/wyoming/snapshots/test_stt.ambr @@ -6,7 +6,7 @@ 'language': 'en', }), 'payload': None, - 'type': 'transcibe', + 'type': 'transcribe', }), dict({ 'data': dict({ diff --git a/tests/components/wyoming/test_satellite.py b/tests/components/wyoming/test_satellite.py index 06ae337a19c..50252007aa5 100644 --- a/tests/components/wyoming/test_satellite.py +++ b/tests/components/wyoming/test_satellite.py @@ -8,6 +8,7 @@ import wave from wyoming.asr import Transcribe, Transcript from wyoming.audio import AudioChunk, AudioStart, AudioStop +from wyoming.error import Error from wyoming.event import Event from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.satellite import RunSatellite @@ -96,6 +97,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): self.tts_audio_stop_event = asyncio.Event() self.tts_audio_chunk: AudioChunk | None = None + self.error_event = asyncio.Event() + self.error: Error | None = None + self._mic_audio_chunk = AudioChunk( rate=16000, width=2, channels=1, audio=b"chunk" ).event() @@ -135,6 +139,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient): self.tts_audio_chunk_event.set() elif AudioStop.is_type(event.type): self.tts_audio_stop_event.set() + elif Error.is_type(event.type): + self.error = Error.from_event(event) + self.error_event.set() async def read_event(self) -> Event | None: """Receive.""" @@ -175,8 +182,9 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None: await mock_client.connect_event.wait() await mock_client.run_satellite_event.wait() - mock_run_pipeline.assert_called() + mock_run_pipeline.assert_called_once() event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] + assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id # Start detecting wake word event_callback( @@ -458,3 +466,43 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None # Sensor should have been turned off assert not device.is_active + + +async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None: + """Test satellite error occurring during pipeline run.""" + events = [ + RunPipeline( + start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS + ).event(), + ] # no audio chunks after RunPipeline + + with patch( + "homeassistant.components.wyoming.data.load_wyoming_info", + return_value=SATELLITE_INFO, + ), patch( + "homeassistant.components.wyoming.satellite.AsyncTcpClient", + SatelliteAsyncTcpClient(events), + ) as mock_client, patch( + "homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream", + ) as mock_run_pipeline: + await setup_config_entry(hass) + + async with asyncio.timeout(1): + await mock_client.connect_event.wait() + await mock_client.run_satellite_event.wait() + + mock_run_pipeline.assert_called_once() + event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] + event_callback( + assist_pipeline.PipelineEvent( + assist_pipeline.PipelineEventType.ERROR, + {"code": "test code", "message": "test message"}, + ) + ) + + async with asyncio.timeout(1): + await mock_client.error_event.wait() + + assert mock_client.error is not None + assert mock_client.error.text == "test message" + assert mock_client.error.code == "test code"