mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 21:27:38 +00:00
More thorough checks in ESPHome voice assistant UDP server (#109394)
* More thorough checks in UDP server * Simplify and change to stop_requested * Check transport
This commit is contained in:
parent
ae210886c1
commit
3347a3f8a6
@ -352,7 +352,6 @@ class ESPHomeManager:
|
||||
if self.voice_assistant_udp_server is not None:
|
||||
_LOGGER.warning("Voice assistant UDP server was not stopped")
|
||||
self.voice_assistant_udp_server.stop()
|
||||
self.voice_assistant_udp_server.close()
|
||||
self.voice_assistant_udp_server = None
|
||||
|
||||
hass = self.hass
|
||||
|
@ -1,4 +1,5 @@
|
||||
"""ESPHome voice assistant support."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@ -67,7 +68,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
"""Receive UDP packets and forward them to the voice assistant."""
|
||||
|
||||
started = False
|
||||
stopped = False
|
||||
stop_requested = False
|
||||
transport: asyncio.DatagramTransport | None = None
|
||||
remote_addr: tuple[str, int] | None = None
|
||||
|
||||
@ -92,6 +93,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
self._tts_done = asyncio.Event()
|
||||
self._tts_task: asyncio.Task | None = None
|
||||
|
||||
@property
|
||||
def is_running(self) -> bool:
|
||||
"""True if the UDP server is started and hasn't been asked to stop."""
|
||||
return self.started and (not self.stop_requested)
|
||||
|
||||
async def start_server(self) -> int:
|
||||
"""Start accepting connections."""
|
||||
|
||||
@ -99,7 +105,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
"""Accept connection."""
|
||||
if self.started:
|
||||
raise RuntimeError("Can only start once")
|
||||
if self.stopped:
|
||||
if self.stop_requested:
|
||||
raise RuntimeError("No longer accepting connections")
|
||||
|
||||
self.started = True
|
||||
@ -124,7 +130,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
@callback
|
||||
def datagram_received(self, data: bytes, addr: tuple[str, int]) -> None:
|
||||
"""Handle incoming UDP packet."""
|
||||
if not self.started or self.stopped:
|
||||
if not self.is_running:
|
||||
return
|
||||
if self.remote_addr is None:
|
||||
self.remote_addr = addr
|
||||
@ -142,19 +148,19 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
def stop(self) -> None:
|
||||
"""Stop the receiver."""
|
||||
self.queue.put_nowait(b"")
|
||||
self.started = False
|
||||
self.stopped = True
|
||||
self.close()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the receiver."""
|
||||
self.started = False
|
||||
self.stopped = True
|
||||
self.stop_requested = True
|
||||
|
||||
if self.transport is not None:
|
||||
self.transport.close()
|
||||
|
||||
async def _iterate_packets(self) -> AsyncIterable[bytes]:
|
||||
"""Iterate over incoming packets."""
|
||||
if not self.started or self.stopped:
|
||||
if not self.is_running:
|
||||
raise RuntimeError("Not running")
|
||||
|
||||
while data := await self.queue.get():
|
||||
@ -303,8 +309,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
|
||||
async def _send_tts(self, media_id: str) -> None:
|
||||
"""Send TTS audio to device via UDP."""
|
||||
# Always send stream start/end events
|
||||
self.handle_event(VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {})
|
||||
|
||||
try:
|
||||
if self.transport is None:
|
||||
if (not self.is_running) or (self.transport is None):
|
||||
return
|
||||
|
||||
extension, data = await tts.async_get_media_source_audio(
|
||||
@ -337,15 +346,11 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
||||
|
||||
_LOGGER.debug("Sending %d bytes of audio", audio_bytes_size)
|
||||
|
||||
self.handle_event(
|
||||
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START, {}
|
||||
)
|
||||
|
||||
bytes_per_sample = stt.AudioBitRates.BITRATE_16 // 8
|
||||
sample_offset = 0
|
||||
samples_left = audio_bytes_size // bytes_per_sample
|
||||
|
||||
while samples_left > 0:
|
||||
while (samples_left > 0) and self.is_running:
|
||||
bytes_offset = sample_offset * bytes_per_sample
|
||||
chunk: bytes = audio_bytes[bytes_offset : bytes_offset + 1024]
|
||||
samples_in_chunk = len(chunk) // bytes_per_sample
|
||||
|
@ -70,6 +70,19 @@ def voice_assistant_udp_server_v2(
|
||||
return voice_assistant_udp_server(entry=mock_voice_assistant_v2_entry)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_wav() -> bytes:
|
||||
"""Return one second of empty WAV audio."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
return wav_io.getvalue()
|
||||
|
||||
|
||||
async def test_pipeline_events(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||
@ -241,11 +254,13 @@ async def test_udp_server_multiple(
|
||||
):
|
||||
await voice_assistant_udp_server_v1.start_server()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
), pytest.raises(RuntimeError):
|
||||
pass
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
await voice_assistant_udp_server_v1.start_server()
|
||||
|
||||
|
||||
@ -257,10 +272,13 @@ async def test_udp_server_after_stopped(
|
||||
) -> None:
|
||||
"""Test that the UDP server raises an error if started after stopped."""
|
||||
voice_assistant_udp_server_v1.close()
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
), pytest.raises(RuntimeError):
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||
new=unused_udp_port_factory(),
|
||||
),
|
||||
pytest.raises(RuntimeError),
|
||||
):
|
||||
await voice_assistant_udp_server_v1.start_server()
|
||||
|
||||
|
||||
@ -362,35 +380,33 @@ async def test_send_tts_not_called_when_empty(
|
||||
async def test_send_tts(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
test_wav,
|
||||
) -> None:
|
||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(16000)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
wav_bytes = wav_io.getvalue()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", wav_bytes),
|
||||
return_value=("wav", test_wav),
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
with patch.object(
|
||||
voice_assistant_udp_server_v2.transport, "is_closing", return_value=False
|
||||
):
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {
|
||||
"media_id": _TEST_MEDIA_ID,
|
||||
"url": _TEST_OUTPUT_URL,
|
||||
}
|
||||
},
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
await voice_assistant_udp_server_v2._tts_done.wait()
|
||||
await voice_assistant_udp_server_v2._tts_done.wait()
|
||||
|
||||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
||||
voice_assistant_udp_server_v2.transport.sendto.assert_called()
|
||||
|
||||
|
||||
async def test_send_tts_wrong_sample_rate(
|
||||
@ -400,17 +416,20 @@ async def test_send_tts_wrong_sample_rate(
|
||||
"""Test the UDP server calls sendto to transmit audio data to device."""
|
||||
with io.BytesIO() as wav_io:
|
||||
with wave.open(wav_io, "wb") as wav_file:
|
||||
wav_file.setframerate(22050) # should be 16000
|
||||
wav_file.setframerate(22050)
|
||||
wav_file.setsampwidth(2)
|
||||
wav_file.setnchannels(1)
|
||||
wav_file.writeframes(bytes(_ONE_SECOND))
|
||||
|
||||
wav_bytes = wav_io.getvalue()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", wav_bytes),
|
||||
), pytest.raises(ValueError):
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", wav_bytes),
|
||||
),
|
||||
pytest.raises(ValueError),
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
@ -431,10 +450,14 @@ async def test_send_tts_wrong_format(
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
) -> None:
|
||||
"""Test that only WAV audio will be streamed."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("raw", bytes(1024)),
|
||||
), pytest.raises(ValueError):
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("raw", bytes(1024)),
|
||||
),
|
||||
pytest.raises(ValueError),
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = True
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
@ -450,6 +473,33 @@ async def test_send_tts_wrong_format(
|
||||
await voice_assistant_udp_server_v2._tts_task # raises ValueError
|
||||
|
||||
|
||||
async def test_send_tts_not_started(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
test_wav,
|
||||
) -> None:
|
||||
"""Test the UDP server does not call sendto when not started."""
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.tts.async_get_media_source_audio",
|
||||
return_value=("wav", test_wav),
|
||||
):
|
||||
voice_assistant_udp_server_v2.started = False
|
||||
voice_assistant_udp_server_v2.transport = Mock(spec=asyncio.DatagramTransport)
|
||||
|
||||
voice_assistant_udp_server_v2._event_callback(
|
||||
PipelineEvent(
|
||||
type=PipelineEventType.TTS_END,
|
||||
data={
|
||||
"tts_output": {"media_id": _TEST_MEDIA_ID, "url": _TEST_OUTPUT_URL}
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
await voice_assistant_udp_server_v2._tts_done.wait()
|
||||
|
||||
voice_assistant_udp_server_v2.transport.sendto.assert_not_called()
|
||||
|
||||
|
||||
async def test_wake_word(
|
||||
hass: HomeAssistant,
|
||||
voice_assistant_udp_server_v2: VoiceAssistantUDPServer,
|
||||
@ -459,11 +509,12 @@ async def test_wake_word(
|
||||
async def async_pipeline_from_audio_stream(*args, start_stage, **kwargs):
|
||||
assert start_stage == PipelineStage.WAKE_WORD
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
), patch(
|
||||
"asyncio.Event.wait" # TTS wait event
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch("asyncio.Event.wait"), # TTS wait event
|
||||
):
|
||||
voice_assistant_udp_server_v2.transport = Mock()
|
||||
|
||||
@ -515,10 +566,15 @@ async def test_wake_word_abort_exception(
|
||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||
raise WakeWordDetectionAborted
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
), patch.object(voice_assistant_udp_server_v2, "handle_event") as mock_handle_event:
|
||||
with (
|
||||
patch(
|
||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||
new=async_pipeline_from_audio_stream,
|
||||
),
|
||||
patch.object(
|
||||
voice_assistant_udp_server_v2, "handle_event"
|
||||
) as mock_handle_event,
|
||||
):
|
||||
voice_assistant_udp_server_v2.transport = Mock()
|
||||
|
||||
await voice_assistant_udp_server_v2.run_pipeline(
|
||||
|
Loading…
x
Reference in New Issue
Block a user