More thorough checks in ESPHome voice assistant UDP server (#109394)

* More thorough checks in UDP server

* Simplify and change to stop_requested

* Check transport
This commit is contained in:
Michael Hansen 2024-02-02 20:26:44 -06:00 committed by GitHub
parent ae210886c1
commit 3347a3f8a6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 122 additions and 62 deletions

View File

@ -352,7 +352,6 @@ class ESPHomeManager:
if self.voice_assistant_udp_server is not None: if self.voice_assistant_udp_server is not None:
_LOGGER.warning("Voice assistant UDP server was not stopped") _LOGGER.warning("Voice assistant UDP server was not stopped")
self.voice_assistant_udp_server.stop() self.voice_assistant_udp_server.stop()
self.voice_assistant_udp_server.close()
self.voice_assistant_udp_server = None self.voice_assistant_udp_server = None
hass = self.hass hass = self.hass

View File

@ -1,4 +1,5 @@
"""ESPHome voice assistant support.""" """ESPHome voice assistant support."""
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
@ -67,7 +68,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Receive UDP packets and forward them to the voice assistant.""" """Receive UDP packets and forward them to the voice assistant."""
started = False started = False
stopped = False stop_requested = False
transport: asyncio.DatagramTransport | None = None transport: asyncio.DatagramTransport | None = None
remote_addr: tuple[str, int] | None = None remote_addr: tuple[str, int] | None = None
@ -92,6 +93,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
self._tts_done = asyncio.Event() self._tts_done = asyncio.Event()
self._tts_task: asyncio.Task | None = None self._tts_task: asyncio.Task | None = None
@property
def is_running(self) -> bool:
"""True if the UDP server is started and hasn't been asked to stop."""
return self.started and (not self.stop_requested)
async def start_server(self) -> int: async def start_server(self) -> int:
"""Start accepting connections.""" """Start accepting connections."""
@ -99,7 +105,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
"""Accept connection.""" """Accept connection."""
if self.started: if self.started:
raise RuntimeError("Can only start once") raise RuntimeError("Can only start once")
if self.stopped: if self.stop_requested:
raise RuntimeError("No longer accepting connections") raise RuntimeError("No longer accepting connections")
self.started = True self.started = True
@ -124,7 +130,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
@callback @callback
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None: def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
"""Handle incoming UDP packet.""" """Handle incoming UDP packet."""
if not self.started or self.stopped: if not self.is_running:
return return
if self.remote_addr is None: if self.remote_addr is None:
self.remote_addr = addr self.remote_addr = addr
@ -142,19 +148,19 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
def stop(self) -> None: def stop(self) -> None:
"""Stop the receiver.""" """Stop the receiver."""
self.queue.put_nowait(b"") self.queue.put_nowait(b"")
self.started = False self.close()
self.stopped = True
def close(self) -> None: def close(self) -> None:
"""Close the receiver.""" """Close the receiver."""
self.started = False self.started = False
self.stopped = True self.stop_requested = True
if self.transport is not None: if self.transport is not None:
self.transport.close() self.transport.close()
async def _iterate_packets(self) -> AsyncIterable[bytes]: async def _iterate_packets(self) -> AsyncIterable[bytes]:
"""Iterate over incoming packets.""" """Iterate over incoming packets."""
if not self.started or self.stopped: if not self.is_running:
raise RuntimeError("Not running") raise RuntimeError("Not running")
while data := await self.queue.get(): while data := await self.queue.get():
@ -303,8 +309,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
async def _send_tts(self, media_id: str) -> None: async def _send_tts(self, media_id: str) -> None:
"""Send TTS audio to device via UDP.""" """Send TTS audio to device via UDP."""
# Always send stream start/end events
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
try: try:
if self.transport is None: if (not self.is_running) or (self.transport is None):
return return
extension, data = await tts.async_get_media_source_audio( extension, data = await tts.async_get_media_source_audio(
@ -337,15 +346,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
_LOGGER.debug("Sending %d bytes of audio", audio_bytes_size) _LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)
self.handle_event(
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
)
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8 bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
sample_offset = 0 sample_offset = 0
samples_left = audio_bytes_size // bytes_per_sample samples_left = audio_bytes_size // bytes_per_sample
while samples_left > 0: while (samples_left > 0) and self.is_running:
bytes_offset = sample_offset * bytes_per_sample bytes_offset = sample_offset * bytes_per_sample
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024] chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
samples_in_chunk = len(chunk) // bytes_per_sample samples_in_chunk = len(chunk) // bytes_per_sample

View File

@ -70,6 +70,19 @@ def voice_assistant_udp_server_v2(
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry) return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry)
@pytest.fixture
def test_wav() -> bytes:
"""Return one second of empty WAV audio."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
return wav_io.getvalue()
async def test_pipeline_events( async def test_pipeline_events(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer, voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
@ -241,11 +254,13 @@ async def test_udp_server_multiple(
): ):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_server_v1.start_server()
with patch( with (
"homeassistant.components.esphome.voice_assistant.UDP_PORT", patch(
new=unused_udp_port_factory(), "homeassistant.components.esphome.voice_assistant.UDP_PORT",
), pytest.raises(RuntimeError): new=unused_udp_port_factory(),
pass ),
pytest.raises(RuntimeError),
):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_server_v1.start_server()
@ -257,10 +272,13 @@ async def test_udp_server_after_stopped(
) -> None: ) -> None:
"""Test that the UDP server raises an error if started after stopped.""" """Test that the UDP server raises an error if started after stopped."""
voice_assistant_udp_server_v1.close() voice_assistant_udp_server_v1.close()
with patch( with (
"homeassistant.components.esphome.voice_assistant.UDP_PORT", patch(
new=unused_udp_port_factory(), "homeassistant.components.esphome.voice_assistant.UDP_PORT",
), pytest.raises(RuntimeError): new=unused_udp_port_factory(),
),
pytest.raises(RuntimeError),
):
await voice_assistant_udp_server_v1.start_server() await voice_assistant_udp_server_v1.start_server()
@ -362,35 +380,33 @@ async def test_send_tts_not_called_when_empty(
async def test_send_tts( async def test_send_tts(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
test_wav,
) -> None: ) -> None:
"""Test the UDP server calls sendto to transmit audio data to device.""" """Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue()
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes), return_value=("wav", test_wav),
): ):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
with patch.object(
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_server_v2.transport, "is_closing", return_value=False
PipelineEvent( ):
type=PipelineEventType.TTS_END, voice_assistant_udp_server_v2._event_callback(
data={ PipelineEvent(
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL} type=PipelineEventType.TTS_END,
}, data={
"tts_output": {
"media_id": _TEST_MEDIA_ID,
"url": _TEST_OUTPUT_URL,
}
},
)
) )
)
await voice_assistant_udp_server_v2._tts_done.wait() await voice_assistant_udp_server_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_called() voice_assistant_udp_server_v2.transport.sendto.assert_called()
async def test_send_tts_wrong_sample_rate( async def test_send_tts_wrong_sample_rate(
@ -400,17 +416,20 @@ async def test_send_tts_wrong_sample_rate(
"""Test the UDP server calls sendto to transmit audio data to device.""" """Test the UDP server calls sendto to transmit audio data to device."""
with io.BytesIO() as wav_io: with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file: with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(22050) # should be 16000 wav_file.setframerate(22050)
wav_file.setsampwidth(2) wav_file.setsampwidth(2)
wav_file.setnchannels(1) wav_file.setnchannels(1)
wav_file.writeframes(bytes(_ONE_SECOND)) wav_file.writeframes(bytes(_ONE_SECOND))
wav_bytes = wav_io.getvalue() wav_bytes = wav_io.getvalue()
with (
with patch( patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", wav_bytes), return_value=("wav", wav_bytes),
), pytest.raises(ValueError): ),
pytest.raises(ValueError),
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_server_v2._event_callback(
@ -431,10 +450,14 @@ async def test_send_tts_wrong_format(
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
) -> None: ) -> None:
"""Test that only WAV audio will be streamed.""" """Test that only WAV audio will be streamed."""
with patch( with (
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio", patch(
return_value=("raw", bytes(1024)), "homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
), pytest.raises(ValueError): return_value=("raw", bytes(1024)),
),
pytest.raises(ValueError),
):
voice_assistant_udp_server_v2.started = True
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport) voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback( voice_assistant_udp_server_v2._event_callback(
@ -450,6 +473,33 @@ async def test_send_tts_wrong_format(
await voice_assistant_udp_server_v2._tts_task # raises ValueError await voice_assistant_udp_server_v2._tts_task # raises ValueError
async def test_send_tts_not_started(
hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
test_wav,
) -> None:
"""Test the UDP server does not call sendto when not started."""
with patch(
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
return_value=("wav", test_wav),
):
voice_assistant_udp_server_v2.started = False
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
voice_assistant_udp_server_v2._event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
},
)
)
await voice_assistant_udp_server_v2._tts_done.wait()
voice_assistant_udp_server_v2.transport.sendto.assert_not_called()
async def test_wake_word( async def test_wake_word(
hass: HomeAssistant, hass: HomeAssistant,
voice_assistant_udp_server_v2: VoiceAssistantUDPServer, voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
@ -459,11 +509,12 @@ async def test_wake_word(
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs): async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
assert start_stage == PipelineStage.WAKE_WORD assert start_stage == PipelineStage.WAKE_WORD
with patch( with (
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", patch(
new=async_pipeline_from_audio_stream, "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
), patch( new=async_pipeline_from_audio_stream,
"asyncio.Event.wait" # TTS wait event ),
patch("asyncio.Event.wait"), # TTS wait event
): ):
voice_assistant_udp_server_v2.transport = Mock() voice_assistant_udp_server_v2.transport = Mock()
@ -515,10 +566,15 @@ async def test_wake_word_abort_exception(
async def async_pipeline_from_audio_stream(*args, **kwargs): async def async_pipeline_from_audio_stream(*args, **kwargs):
raise WakeWordDetectionAborted raise WakeWordDetectionAborted
with patch( with (
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", patch(
new=async_pipeline_from_audio_stream, "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
), patch.object(voice_assistant_udp_server_v2, "handle_event") as mock_handle_event: new=async_pipeline_from_audio_stream,
),
patch.object(
voice_assistant_udp_server_v2, "handle_event"
) as mock_handle_event,
):
voice_assistant_udp_server_v2.transport = Mock() voice_assistant_udp_server_v2.transport = Mock()
await voice_assistant_udp_server_v2.run_pipeline( await voice_assistant_udp_server_v2.run_pipeline(