mirror of
https://github.com/home-assistant/core.git
synced 2025-07-23 05:07:41 +00:00
Set pipeline_id from pipeline select (#92085)
This commit is contained in:
parent
fdfd567ee5
commit
29ca43acf6
@ -301,7 +301,7 @@ async def async_setup_entry( # noqa: C901
|
|||||||
if voice_assistant_udp_server is not None:
|
if voice_assistant_udp_server is not None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
voice_assistant_udp_server = VoiceAssistantUDPServer(hass)
|
voice_assistant_udp_server = VoiceAssistantUDPServer(hass, entry_data)
|
||||||
port = await voice_assistant_udp_server.start_server()
|
port = await voice_assistant_udp_server.start_server()
|
||||||
|
|
||||||
hass.async_create_background_task(
|
hass.async_create_background_task(
|
||||||
|
@ -14,10 +14,13 @@ from homeassistant.components.assist_pipeline import (
|
|||||||
PipelineEvent,
|
PipelineEvent,
|
||||||
PipelineEventType,
|
PipelineEventType,
|
||||||
async_pipeline_from_audio_stream,
|
async_pipeline_from_audio_stream,
|
||||||
|
select as pipeline_select,
|
||||||
)
|
)
|
||||||
from homeassistant.components.media_player import async_process_play_media_url
|
from homeassistant.components.media_player import async_process_play_media_url
|
||||||
from homeassistant.core import Context, HomeAssistant, callback
|
from homeassistant.core import Context, HomeAssistant, callback
|
||||||
|
|
||||||
|
from .const import DOMAIN
|
||||||
|
from .entry_data import RuntimeEntryData
|
||||||
from .enum_mapper import EsphomeEnumMapper
|
from .enum_mapper import EsphomeEnumMapper
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
@ -48,10 +51,18 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
queue: asyncio.Queue[bytes] | None = None
|
queue: asyncio.Queue[bytes] | None = None
|
||||||
transport: asyncio.DatagramTransport | None = None
|
transport: asyncio.DatagramTransport | None = None
|
||||||
|
|
||||||
def __init__(self, hass: HomeAssistant) -> None:
|
def __init__(
|
||||||
|
self,
|
||||||
|
hass: HomeAssistant,
|
||||||
|
entry_data: RuntimeEntryData,
|
||||||
|
) -> None:
|
||||||
"""Initialize UDP receiver."""
|
"""Initialize UDP receiver."""
|
||||||
self.context = Context()
|
self.context = Context()
|
||||||
self.hass = hass
|
self.hass = hass
|
||||||
|
|
||||||
|
assert entry_data.device_info is not None
|
||||||
|
self.device_info = entry_data.device_info
|
||||||
|
|
||||||
self.queue = asyncio.Queue()
|
self.queue = asyncio.Queue()
|
||||||
|
|
||||||
async def start_server(self) -> int:
|
async def start_server(self) -> int:
|
||||||
@ -155,7 +166,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
context=self.context,
|
context=self.context,
|
||||||
event_callback=handle_pipeline_event,
|
event_callback=handle_pipeline_event,
|
||||||
stt_metadata=stt.SpeechMetadata(
|
stt_metadata=stt.SpeechMetadata(
|
||||||
language="",
|
language="", # set in async_pipeline_from_audio_stream
|
||||||
format=stt.AudioFormats.WAV,
|
format=stt.AudioFormats.WAV,
|
||||||
codec=stt.AudioCodecs.PCM,
|
codec=stt.AudioCodecs.PCM,
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
@ -163,4 +174,7 @@ class VoiceAssistantUDPServer(asyncio.DatagramProtocol):
|
|||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
),
|
),
|
||||||
stt_stream=self._iterate_packets(),
|
stt_stream=self._iterate_packets(),
|
||||||
|
pipeline_id=pipeline_select.get_chosen_pipeline(
|
||||||
|
self.hass, DOMAIN, self.device_info.mac_address
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
@ -8,6 +8,8 @@ import async_timeout
|
|||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import assist_pipeline, esphome
|
from homeassistant.components import assist_pipeline, esphome
|
||||||
|
from homeassistant.components.esphome import DomainData
|
||||||
|
from homeassistant.components.esphome.voice_assistant import VoiceAssistantUDPServer
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
_TEST_INPUT_TEXT = "This is an input test"
|
_TEST_INPUT_TEXT = "This is an input test"
|
||||||
@ -15,7 +17,19 @@ _TEST_OUTPUT_TEXT = "This is an output test"
|
|||||||
_TEST_OUTPUT_URL = "output.mp3"
|
_TEST_OUTPUT_URL = "output.mp3"
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_events(hass: HomeAssistant) -> None:
|
@pytest.fixture
|
||||||
|
def voice_assistant_udp_server_v1(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_voice_assistant_v1_entry,
|
||||||
|
) -> VoiceAssistantUDPServer:
|
||||||
|
"""Return the UDP server."""
|
||||||
|
entry_data = DomainData.get(hass).get_entry_data(mock_voice_assistant_v1_entry)
|
||||||
|
return VoiceAssistantUDPServer(hass, entry_data)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_events(
|
||||||
|
hass: HomeAssistant, voice_assistant_udp_server_v1: VoiceAssistantUDPServer
|
||||||
|
) -> None:
|
||||||
"""Test that the pipeline function is called."""
|
"""Test that the pipeline function is called."""
|
||||||
|
|
||||||
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
async def async_pipeline_from_audio_stream(*args, **kwargs):
|
||||||
@ -67,75 +81,74 @@ async def test_pipeline_events(hass: HomeAssistant) -> None:
|
|||||||
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
"homeassistant.components.esphome.voice_assistant.async_pipeline_from_audio_stream",
|
||||||
new=async_pipeline_from_audio_stream,
|
new=async_pipeline_from_audio_stream,
|
||||||
):
|
):
|
||||||
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
|
voice_assistant_udp_server_v1.transport = Mock()
|
||||||
server.transport = Mock()
|
|
||||||
|
|
||||||
await server.run_pipeline(handle_event)
|
await voice_assistant_udp_server_v1.run_pipeline(handle_event)
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server(
|
async def test_udp_server(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
|
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test the UDP server runs and queues incoming data."""
|
"""Test the UDP server runs and queues incoming data."""
|
||||||
port_to_use = unused_udp_port_factory()
|
port_to_use = unused_udp_port_factory()
|
||||||
|
|
||||||
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT", new=port_to_use
|
||||||
):
|
):
|
||||||
port = await server.start_server()
|
port = await voice_assistant_udp_server_v1.start_server()
|
||||||
assert port == port_to_use
|
assert port == port_to_use
|
||||||
|
|
||||||
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||||
|
|
||||||
assert server.queue.qsize() == 0
|
assert voice_assistant_udp_server_v1.queue.qsize() == 0
|
||||||
sock.sendto(b"test", ("127.0.0.1", port))
|
sock.sendto(b"test", ("127.0.0.1", port))
|
||||||
|
|
||||||
# Give the socket some time to send/receive the data
|
# Give the socket some time to send/receive the data
|
||||||
async with async_timeout.timeout(1):
|
async with async_timeout.timeout(1):
|
||||||
while server.queue.qsize() == 0:
|
while voice_assistant_udp_server_v1.queue.qsize() == 0:
|
||||||
await asyncio.sleep(0.1)
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
assert server.queue.qsize() == 1
|
assert voice_assistant_udp_server_v1.queue.qsize() == 1
|
||||||
|
|
||||||
server.stop()
|
voice_assistant_udp_server_v1.stop()
|
||||||
|
|
||||||
assert server.transport.is_closing()
|
assert voice_assistant_udp_server_v1.transport.is_closing()
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_multiple(
|
async def test_udp_server_multiple(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
|
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the UDP server raises an error if started twice."""
|
"""Test that the UDP server raises an error if started twice."""
|
||||||
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||||
new=unused_udp_port_factory(),
|
new=unused_udp_port_factory(),
|
||||||
):
|
):
|
||||||
await server.start_server()
|
await voice_assistant_udp_server_v1.start_server()
|
||||||
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||||
new=unused_udp_port_factory(),
|
new=unused_udp_port_factory(),
|
||||||
), pytest.raises(RuntimeError):
|
), pytest.raises(RuntimeError):
|
||||||
pass
|
pass
|
||||||
await server.start_server()
|
await voice_assistant_udp_server_v1.start_server()
|
||||||
|
|
||||||
|
|
||||||
async def test_udp_server_after_stopped(
|
async def test_udp_server_after_stopped(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
socket_enabled,
|
socket_enabled,
|
||||||
unused_udp_port_factory,
|
unused_udp_port_factory,
|
||||||
|
voice_assistant_udp_server_v1: VoiceAssistantUDPServer,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test that the UDP server raises an error if started after stopped."""
|
"""Test that the UDP server raises an error if started after stopped."""
|
||||||
server = esphome.voice_assistant.VoiceAssistantUDPServer(hass)
|
voice_assistant_udp_server_v1.stop()
|
||||||
server.stop()
|
|
||||||
with patch(
|
with patch(
|
||||||
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
"homeassistant.components.esphome.voice_assistant.UDP_PORT",
|
||||||
new=unused_udp_port_factory(),
|
new=unused_udp_port_factory(),
|
||||||
), pytest.raises(RuntimeError):
|
), pytest.raises(RuntimeError):
|
||||||
await server.start_server()
|
await voice_assistant_udp_server_v1.start_server()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user