diff --git a/homeassistant/components/esphome/__init__.py b/homeassistant/components/esphome/__init__.py index 3bec7f883d0..6ce5f656d6e 100644 --- a/homeassistant/components/esphome/__init__.py +++ b/homeassistant/components/esphome/__init__.py @@ -301,7 +301,7 @@ async def async_setup_entry( # noqa: C901 if voice_assistant_udp_server is not None: return None - voice_assistant_udp_server = VoiceAssistantUDPServer(hass) + voice_assistant_udp_server = VoiceAssistantUDPServer(hass, entry_data) port = await voice_assistant_udp_server.start_server() hass.async_create_background_task( diff --git a/homeassistant/components/esphome/voice_assistant.py b/homeassistant/components/esphome/voice_assistant.py index 9b35fc7972a..b6c76e00f4c 100644 --- a/homeassistant/components/esphome/voice_assistant.py +++ b/homeassistant/components/esphome/voice_assistant.py @@ -14,10 +14,13 @@ from homeassistant.components.assist_pipeline import ( PipelineEvent, PipelineEventType, async_pipeline_from_audio_stream, + select as pipeline_select, ) from homeassistant.components.media_player import async_process_play_media_url from homeassistant.core import Context, HomeAssistant, callback +from .const import DOMAIN +from .entry_data import RuntimeEntryData from .enum_mapper import EsphomeEnumMapper _LOGGER = logging.getLogger(__name__) @@ -48,10 +51,18 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): queue: asyncio.Queue[bytes] | None = None transport: asyncio.DatagramTransport | None = None - def __init__(self, hass: HomeAssistant) -> None: + def __init__( + self, + hass: HomeAssistant, + entry_data: RuntimeEntryData, + ) -> None: """Initialize UDP receiver.""" self.context = Context() self.hass = hass + + assert entry_data.device_info is not None + self.device_info = entry_data.device_info + self.queue = asyncio.Queue() async def start_server(self) -> int: @@ -155,7 +166,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): context=self.context, event_callback=handle_pipeline_event, stt_metadata=stt.SpeechMetadata( - language="", + language="", # set in async_pipeline_from_audio_stream format=stt.AudioFormats.WAV, codec=stt.AudioCodecs.PCM, bit_rate=stt.AudioBitRates.BITRATE_16, @@ -163,4 +174,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol): channel=stt.AudioChannels.CHANNEL_MONO, ), stt_stream=self._iterate_packets(), + pipeline_id=pipeline_select.get_chosen_pipeline( + self.hass, DOMAIN, self.device_info.mac_address + ), ) diff --git a/tests/components/esphome/test_voice_assistant.py b/tests/components/esphome/test_voice_assistant.py index ee6f4f7289f..e1fe41829c2 100644 --- a/tests/components/esphome/test_voice_assistant.py +++ b/tests/components/esphome/test_voice_assistant.py @@ -8,6 +8,8 @@ import async_timeout import pytest from homeassistant.components import assist_pipeline, esphome +from homeassistant.components.esphome import DomainData +from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer from homeassistant.core import HomeAssistant _TEST_INPUT_TEXT = "This is an input test" @@ -15,7 +17,19 @@ _TEST_OUTPUT_TEXT = "This is an output test" _TEST_OUTPUT_URL = "output.mp3" -async def test_pipeline_events(hass: HomeAssistant) -> None: +@pytest.fixture +def voice_assistant_udp_server_v1( + hass: HomeAssistant, + mock_voice_assistant_v1_entry, +) -> VoiceAssistantUDPServer: + """Return the UDP server.""" + entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry) + return VoiceAssistantUDPServer(hass, entry_data) + + +async def test_pipeline_events( + hass: HomeAssistant, voice_assistant_udp_server_v1: VoiceAssistantUDPServer +) -> None: """Test that the pipeline function is called.""" async def async_pipeline_from_audio_stream(*args, **kwargs): @@ -67,75 +81,74 @@ async def test_pipeline_events(hass: HomeAssistant) -> None: "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", new=async_pipeline_from_audio_stream, ): - server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) - server.transport = Mock() + voice_assistant_udp_server_v1.transport = Mock() - await server.run_pipeline(handle_event) + await voice_assistant_udp_server_v1.run_pipeline(handle_event) async def test_udp_server( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, ) -> None: """Test the UDP server runs and queues incoming data.""" port_to_use = unused_udp_port_factory() - server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use ): - port = await server.start_server() + port = await voice_assistant_udp_server_v1.start_server() assert port == port_to_use sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - assert server.queue.qsize() == 0 + assert voice_assistant_udp_server_v1.queue.qsize() == 0 sock.sendto(b"test", ("127.0.0.1", port)) # Give the socket some time to send/receive the data async with async_timeout.timeout(1): - while server.queue.qsize() == 0: + while voice_assistant_udp_server_v1.queue.qsize() == 0: await asyncio.sleep(0.1) - assert server.queue.qsize() == 1 + assert voice_assistant_udp_server_v1.queue.qsize() == 1 - server.stop() + voice_assistant_udp_server_v1.stop() - assert server.transport.is_closing() + assert voice_assistant_udp_server_v1.transport.is_closing() async def test_udp_server_multiple( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, ) -> None: """Test that the UDP server raises an error if started twice.""" - server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=unused_udp_port_factory(), ): - await server.start_server() + await voice_assistant_udp_server_v1.start_server() with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=unused_udp_port_factory(), ), pytest.raises(RuntimeError): pass - await server.start_server() + await voice_assistant_udp_server_v1.start_server() async def test_udp_server_after_stopped( hass: HomeAssistant, socket_enabled, unused_udp_port_factory, + voice_assistant_udp_server_v1: VoiceAssistantUDPServer, ) -> None: """Test that the UDP server raises an error if started after stopped.""" - server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) - server.stop() + voice_assistant_udp_server_v1.stop() with patch( "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=unused_udp_port_factory(), ), pytest.raises(RuntimeError): - await server.start_server() + await voice_assistant_udp_server_v1.start_server()