Set pipeline_id from pipeline select (#92085)

This commit is contained in:
Jesse Hills 2023-04-27 10:29:08 +12:00 committed by GitHub
parent fdfd567ee5
commit 29ca43acf6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 47 additions and 20 deletions

View File

@ -301,7 +301,7 @@ async def async_setup_entry( # noqa: C901
if voice_assistant_udp_server is not None: if voice_assistant_udp_server is not None:
return None return None
voice_assistant_udp_server = VoiceAssistantUDPServer(hass) voice_assistant_udp_server = VoiceAssistantUDPServer(hass, entry_data)
port = await voice_assistant_udp_server.start_server() port = await voice_assistant_udp_server.start_server()
hass.async_create_background_task( hass.async_create_background_task(

View File

@ -14,10 +14,13 @@ from homeassistant.components.assist_pipeline import (
PipelineEvent, PipelineEvent,
PipelineEventType, PipelineEventType,
async_pipeline_from_audio_stream, async_pipeline_from_audio_stream,
select as pipeline_select,
) )
from homeassistant.components.media_player import async_process_play_media_url from homeassistant.components.media_player import async_process_play_media_url
from homeassistant.core import Context, HomeAssistant, callback from homeassistant.core import Context, HomeAssistant, callback
from .const import DOMAIN
from .entry_data import RuntimeEntryData
from .enum_mapper import EsphomeEnumMapper from .enum_mapper import EsphomeEnumMapper
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -48,10 +51,18 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
queue: asyncio.Queue[bytes] | None = None queue: asyncio.Queue[bytes] | None = None
transport: asyncio.DatagramTransport | None = None transport: asyncio.DatagramTransport | None = None
def __init__(self, hass: HomeAssistant) -> None: def __init__(
self,
hass: HomeAssistant,
entry_data: RuntimeEntryData,
) -> None:
"""Initialize UDP receiver.""" """Initialize UDP receiver."""
self.context = Context() self.context = Context()
self.hass = hass self.hass = hass
assert entry_data.device_info is not None
self.device_info = entry_data.device_info
self.queue = asyncio.Queue() self.queue = asyncio.Queue()
async def start_server(self) -> int: async def start_server(self) -> int:
@ -155,7 +166,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
context=self.context, context=self.context,
event_callback=handle_pipeline_event, event_callback=handle_pipeline_event,
stt_metadata=stt.SpeechMetadata( stt_metadata=stt.SpeechMetadata(
language="", language="", # set in async_pipeline_from_audio_stream
format=stt.AudioFormats.WAV, format=stt.AudioFormats.WAV,
codec=stt.AudioCodecs.PCM, codec=stt.AudioCodecs.PCM,
bit_rate=stt.AudioBitRates.BITRATE_16, bit_rate=stt.AudioBitRates.BITRATE_16,
@ -163,4 +174,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
channel=stt.AudioChannels.CHANNEL_MONO, channel=stt.AudioChannels.CHANNEL_MONO,
), ),
stt_stream=self._iterate_packets(), stt_stream=self._iterate_packets(),
pipeline_id=pipeline_select.get_chosen_pipeline(
self.hass, DOMAIN, self.device_info.mac_address
),
) )

View File

@ -8,6 +8,8 @@ import async_timeout
import pytest import pytest
from homeassistant.components import assist_pipeline, esphome from homeassistant.components import assist_pipeline, esphome
from homeassistant.components.esphome import DomainData
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
_TEST_INPUT_TEXT = "This is an input test" _TEST_INPUT_TEXT = "This is an input test"
@ -15,7 +17,19 @@ _TEST_OUTPUT_TEXT = "This is an output test"
_TEST_OUTPUT_URL = "output.mp3" _TEST_OUTPUT_URL = "output.mp3"
async def test_pipeline_events(hass: HomeAssistant) -> None: @pytest.fixture
def voice_assistant_udp_server_v1(
hass: HomeAssistant,
mock_voice_assistant_v1_entry,
) -> VoiceAssistantUDPServer:
"""Return the UDP server."""
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry)
return VoiceAssistantUDPServer(hass, entry_data)
async def test_pipeline_events(
hass: HomeAssistant, voice_assistant_udp_server_v1: VoiceAssistantUDPServer
) -> None:
"""Test that the pipeline function is called.""" """Test that the pipeline function is called."""
async def async_pipeline_from_audio_stream(*args, **kwargs): async def async_pipeline_from_audio_stream(*args, **kwargs):
@ -67,75 +81,74 @@ async def test_pipeline_events(hass: HomeAssistant) -> None:
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream", "homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream, new=async_pipeline_from_audio_stream,
): ):
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) voice_assistant_udp_server_v1.transport = Mock()
server.transport = Mock()
await server.run_pipeline(handle_event) await voice_assistant_udp_server_v1.run_pipeline(handle_event)
async def test_udp_server( async def test_udp_server(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
) -> None: ) -> None:
"""Test the UDP server runs and queues incoming data.""" """Test the UDP server runs and queues incoming data."""
port_to_use = unused_udp_port_factory() port_to_use = unused_udp_port_factory()
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use "homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
): ):
port = await server.start_server() port = await voice_assistant_udp_server_v1.start_server()
assert port == port_to_use assert port == port_to_use
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
assert server.queue.qsize() == 0 assert voice_assistant_udp_server_v1.queue.qsize() == 0
sock.sendto(b"test", ("127.0.0.1", port)) sock.sendto(b"test", ("127.0.0.1", port))
# Give the socket some time to send/receive the data # Give the socket some time to send/receive the data
async with async_timeout.timeout(1): async with async_timeout.timeout(1):
while server.queue.qsize() == 0: while voice_assistant_udp_server_v1.queue.qsize() == 0:
await asyncio.sleep(0.1) await asyncio.sleep(0.1)
assert server.queue.qsize() == 1 assert voice_assistant_udp_server_v1.queue.qsize() == 1
server.stop() voice_assistant_udp_server_v1.stop()
assert server.transport.is_closing() assert voice_assistant_udp_server_v1.transport.is_closing()
async def test_udp_server_multiple( async def test_udp_server_multiple(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
) -> None: ) -> None:
"""Test that the UDP server raises an error if started twice.""" """Test that the UDP server raises an error if started twice."""
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", "homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(), new=unused_udp_port_factory(),
): ):
await server.start_server() await voice_assistant_udp_server_v1.start_server()
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", "homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(), new=unused_udp_port_factory(),
), pytest.raises(RuntimeError): ), pytest.raises(RuntimeError):
pass pass
await server.start_server() await voice_assistant_udp_server_v1.start_server()
async def test_udp_server_after_stopped( async def test_udp_server_after_stopped(
hass: HomeAssistant, hass: HomeAssistant,
socket_enabled, socket_enabled,
unused_udp_port_factory, unused_udp_port_factory,
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
) -> None: ) -> None:
"""Test that the UDP server raises an error if started after stopped.""" """Test that the UDP server raises an error if started after stopped."""
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass) voice_assistant_udp_server_v1.stop()
server.stop()
with patch( with patch(
"homeassistant.components.esphome.voice_assistant.UDP_PORT", "homeassistant.components.esphome.voice_assistant.UDP_PORT",
new=unused_udp_port_factory(), new=unused_udp_port_factory(),
), pytest.raises(RuntimeError): ), pytest.raises(RuntimeError):
await server.start_server() await voice_assistant_udp_server_v1.start_server()