More thorough checks in ESPHome voice assistant UDP server (#109394)

* More thorough checks in UDP server

* Simplify and change to stop_requested

* Check transport
This commit is contained in:
Michael Hansen 2024-02-02 20:26:44 -06:00 committed by GitHub
parent ae210886c1
commit 3347a3f8a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 122 additions and 62 deletions

View File

@ -352,7 +352,6 @@ class ESPHomeManager:
if self.voice_assistant_udp_server is not None:
_LOGGER.warning("Voice assistant UDP server was not stopped")
self.voice_assistant_udp_server.stop()
self.voice_assistant_udp_server.close()
self.voice_assistant_udp_server = None
hass = self.hass

View File

@ -1,4 +1,5 @@
"""ESPHome voice assistant support."""
from __future__ import annotations
import asyncio
@ -67,7 +68,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant."""
started = False
stopped = False
stop_requested = False
transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None
@ -92,6 +93,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self._tts_done = asyncio.Event()
self._tts_task: asyncio.Task | None = None
@property
def is_running(self) -> bool:
"""True if the UDP server is started and hasn't been asked to stop."""
return self.started and (not self.stop_requested)
async def start_server(self) -> int:
"""Start accepting connections."""
@ -99,7 +105,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Accept connection."""
if self.started:
raise RuntimeError("Can only start once")
if self.stopped:
if self.stop_requested:
raise RuntimeError("No longer accepting connections")
self.started = True
@ -124,7 +130,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
@callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet."""
if not self.started or self.stopped:
if not self.is_running:
return
if self.remote_addr is None:
self.remote_addr = addr
@ -142,19 +148,19 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
def stop(self) -> None:
"""Stop the receiver."""
self.queue.put_nowait(b"")
self.started = False
self.stopped = True
self.close()
def close(self) -> None:
"""Close the receiver."""
self.started = False
self.stopped = True
self.stop_requested = True
if self.transport is not None:
self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets."""
if not self.started or self.stopped:
if not self.is_running:
raise RuntimeError("Not running")
while data := await self.queue.get():
@ -303,8 +309,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to device via UDP."""
# Always send stream start/end events
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
try:
if self.transport is None:
if (not self.is_running) or (self.transport is None):
return
extension, data = await tts.async_get_media_source_audio(
@ -337,15 +346,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
_LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
)
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
sample_offset = 0
samples_left = audio_bytes_size // bytes_per_sample
while samples_left > 0:
while (samples_left > 0) and self.is_running:
bytes_offset = sample_offset * bytes_per_sample
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
samples_in_chunk = len(chunk) // bytes_per_sample

View File

@ -70,6 +70,19 @@ def voice_assistant_udp_server_v2(
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry)
@pytest.fixture
def test_wav() -> bytes:
"""Return one second of empty WAV audio."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
return wav_io.getvalue()
async def test_pipeline_events(
hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
@ -241,11 +254,13 @@ async def test_udp_server_multiple(
):
await voice_assistant_udp_server_v1.start_server()
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
), pytest.raises(RuntimeError):
pass
with (
patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
),
pytest.raises(RuntimeError),
):
await voice_assistant_udp_server_v1.start_server()
@ -257,10 +272,13 @@ async def test_udp_server_after_stopped(
) -> None:
"""Test that the UDP server raises an error if started after stopped."""
voice_assistant_udp_server_v1.close()
with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
), pytest.raises(RuntimeError):
with (
patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(),
),
pytest.raises(RuntimeError),
):
await voice_assistant_udp_server_v1.start_server()
@ -362,35 +380,33 @@ async def test_send_tts_not_called_when_empty(
async def test_send_tts(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
test_wav,
) -> None:
"""Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes),
return_value=("wav", test_wav),
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
with patch.object(
voice_assistant_udp_server_v2.transport, "is_closing", return_value=False
):
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {
"media_id": _TEST_MEDIA_ID,
"url": _TEST_OUTPUT_URL,
}
},
)
)
)
await voice_assistant_udp_server_v2._tts_done.wait()
await voice_assistant_udp_server_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_called()
voice_assistant_udp_server_v2.transport.sendto.assert_called()
async def test_send_tts_wrong_sample_rate(
@ -400,17 +416,20 @@ async def test_send_tts_wrong_sample_rate(
"""Test that TTS audio with the wrong sample rate raises ValueError."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050) # should be 16000
wav_file.setframerate(22050)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes),
), pytest.raises(ValueError):
with (
patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes),
),
pytest.raises(ValueError),
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
@ -431,10 +450,14 @@ async def test_send_tts_wrong_format(
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None:
"""Test that only WAV audio will be streamed."""
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("raw", bytes(1024)),
), pytest.raises(ValueError):
with (
patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("raw", bytes(1024)),
),
pytest.raises(ValueError),
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
@ -450,6 +473,33 @@ async def test_send_tts_wrong_format(
await voice_assistant_udp_server_v2._tts_task # raises ValueError
async def test_send_tts_not_started(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
test_wav,
) -> None:
"""Test the UDP server does not call sendto when not started."""
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", test_wav),
):
voice_assistant_udp_server_v2.started = False
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
)
)
await voice_assistant_udp_server_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_not_called()
async def test_wake_word(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
@ -459,11 +509,12 @@ async def test_wake_word(
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
assert start_stage == PipelineStage.WAKE_WORD
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch(
"asyncio.Event.wait" # TTS wait event
with (
patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch("asyncio.Event.wait"), # TTS wait event
):
voice_assistant_udp_server_v2.transport = Mock()
@ -515,10 +566,15 @@ async def test_wake_word_abort_exception(
async def async_pipeline_from_audio_stream(*args, **kwargs):
raise WakeWordDetectionAborted
with patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
), patch.object(voice_assistant_udp_server_v2, "handle_event") as mock_handle_event:
with (
patch(
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(
voice_assistant_udp_server_v2, "handle_event"
) as mock_handle_event,
):
voice_assistant_udp_server_v2.transport = Mock()
await voice_assistant_udp_server_v2.run_pipeline(