Set device id and forward errors to Wyoming satellites (#105266)

* Set device id and forward errors

* Fix tests
This commit is contained in:
Michael Hansen 2023-12-07 19:44:43 -06:00 committed by GitHub
parent e9f8e7ab50
commit 43daeb2630
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 65 additions and 5 deletions

View File

@ -6,6 +6,6 @@
"dependencies": ["assist_pipeline"], "dependencies": ["assist_pipeline"],
"documentation": "https://www.home-assistant.io/integrations/wyoming", "documentation": "https://www.home-assistant.io/integrations/wyoming",
"iot_class": "local_push", "iot_class": "local_push",
"requirements": ["wyoming==1.3.0"], "requirements": ["wyoming==1.4.0"],
"zeroconf": ["_wyoming._tcp.local."] "zeroconf": ["_wyoming._tcp.local."]
} }

View File

@ -9,6 +9,7 @@ import wave
from wyoming.asr import Transcribe, Transcript from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop from wyoming.audio import AudioChunk, AudioChunkConverter, AudioStart, AudioStop
from wyoming.client import AsyncTcpClient from wyoming.client import AsyncTcpClient
from wyoming.error import Error
from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite from wyoming.satellite import RunSatellite
from wyoming.tts import Synthesize, SynthesizeVoice from wyoming.tts import Synthesize, SynthesizeVoice
@ -239,6 +240,7 @@ class WyomingSatellite:
auto_gain_dbfs=self.device.auto_gain, auto_gain_dbfs=self.device.auto_gain,
volume_multiplier=self.device.volume_multiplier, volume_multiplier=self.device.volume_multiplier,
), ),
device_id=self.device.device_id,
) )
) )
@ -333,6 +335,16 @@ class WyomingSatellite:
if event.data and (tts_output := event.data["tts_output"]): if event.data and (tts_output := event.data["tts_output"]):
media_id = tts_output["media_id"] media_id = tts_output["media_id"]
self.hass.add_job(self._stream_tts(media_id)) self.hass.add_job(self._stream_tts(media_id))
elif event.type == assist_pipeline.PipelineEventType.ERROR:
# Pipeline error
if event.data:
self.hass.add_job(
self._client.write_event(
Error(
text=event.data["message"], code=event.data["code"]
).event()
)
)
async def _connect(self) -> None: async def _connect(self) -> None:
"""Connect to satellite over TCP.""" """Connect to satellite over TCP."""

View File

@ -2760,7 +2760,7 @@ wled==0.17.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.3.0 wyoming==1.4.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -2064,7 +2064,7 @@ wled==0.17.0
wolf-smartset==0.1.11 wolf-smartset==0.1.11
# homeassistant.components.wyoming # homeassistant.components.wyoming
wyoming==1.3.0 wyoming==1.4.0
# homeassistant.components.xbox # homeassistant.components.xbox
xbox-webapi==2.0.11 xbox-webapi==2.0.11

View File

@ -6,7 +6,7 @@
'language': 'en', 'language': 'en',
}), }),
'payload': None, 'payload': None,
'type': 'transcibe', 'type': 'transcribe',
}), }),
dict({ dict({
'data': dict({ 'data': dict({

View File

@ -8,6 +8,7 @@ import wave
from wyoming.asr import Transcribe, Transcript from wyoming.asr import Transcribe, Transcript
from wyoming.audio import AudioChunk, AudioStart, AudioStop from wyoming.audio import AudioChunk, AudioStart, AudioStop
from wyoming.error import Error
from wyoming.event import Event from wyoming.event import Event
from wyoming.pipeline import PipelineStage, RunPipeline from wyoming.pipeline import PipelineStage, RunPipeline
from wyoming.satellite import RunSatellite from wyoming.satellite import RunSatellite
@ -96,6 +97,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.tts_audio_stop_event = asyncio.Event() self.tts_audio_stop_event = asyncio.Event()
self.tts_audio_chunk: AudioChunk | None = None self.tts_audio_chunk: AudioChunk | None = None
self.error_event = asyncio.Event()
self.error: Error | None = None
self._mic_audio_chunk = AudioChunk( self._mic_audio_chunk = AudioChunk(
rate=16000, width=2, channels=1, audio=b"chunk" rate=16000, width=2, channels=1, audio=b"chunk"
).event() ).event()
@ -135,6 +139,9 @@ class SatelliteAsyncTcpClient(MockAsyncTcpClient):
self.tts_audio_chunk_event.set() self.tts_audio_chunk_event.set()
elif AudioStop.is_type(event.type): elif AudioStop.is_type(event.type):
self.tts_audio_stop_event.set() self.tts_audio_stop_event.set()
elif Error.is_type(event.type):
self.error = Error.from_event(event)
self.error_event.set()
async def read_event(self) -> Event | None: async def read_event(self) -> Event | None:
"""Receive.""" """Receive."""
@ -175,8 +182,9 @@ async def test_satellite_pipeline(hass: HomeAssistant) -> None:
await mock_client.connect_event.wait() await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait() await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called() mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"] event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
assert mock_run_pipeline.call_args.kwargs.get("device_id") == device.device_id
# Start detecting wake word # Start detecting wake word
event_callback( event_callback(
@ -458,3 +466,43 @@ async def test_satellite_disconnect_during_pipeline(hass: HomeAssistant) -> None
# Sensor should have been turned off # Sensor should have been turned off
assert not device.is_active assert not device.is_active
async def test_satellite_error_during_pipeline(hass: HomeAssistant) -> None:
"""Test satellite error occurring during pipeline run."""
events = [
RunPipeline(
start_stage=PipelineStage.WAKE, end_stage=PipelineStage.TTS
).event(),
] # no audio chunks after RunPipeline
with patch(
"homeassistant.components.wyoming.data.load_wyoming_info",
return_value=SATELLITE_INFO,
), patch(
"homeassistant.components.wyoming.satellite.AsyncTcpClient",
SatelliteAsyncTcpClient(events),
) as mock_client, patch(
"homeassistant.components.wyoming.satellite.assist_pipeline.async_pipeline_from_audio_stream",
) as mock_run_pipeline:
await setup_config_entry(hass)
async with asyncio.timeout(1):
await mock_client.connect_event.wait()
await mock_client.run_satellite_event.wait()
mock_run_pipeline.assert_called_once()
event_callback = mock_run_pipeline.call_args.kwargs["event_callback"]
event_callback(
assist_pipeline.PipelineEvent(
assist_pipeline.PipelineEventType.ERROR,
{"code": "test code", "message": "test message"},
)
)
async with asyncio.timeout(1):
await mock_client.error_event.wait()
assert mock_client.error is not None
assert mock_client.error.text == "test message"
assert mock_client.error.code == "test code"