core/tests/components/esphome/test_assist_satellite.py
J. Nick Koston 3aa1c60fe3
ESPHome quality improvements round 2 (#143613)
Co-authored-by: Joost Lekkerkerker <joostlek@outlook.com>
2025-04-24 09:51:33 -10:00

1944 lines
63 KiB
Python

"""Test ESPHome voice assistant server."""
import asyncio
from collections.abc import Awaitable, Callable
from dataclasses import replace
import io
import socket
from unittest.mock import ANY, AsyncMock, Mock, patch
import wave
from aioesphomeapi import (
APIClient,
EntityInfo,
EntityState,
MediaPlayerFormatPurpose,
MediaPlayerInfo,
MediaPlayerSupportedFormat,
UserService,
VoiceAssistantAnnounceFinished,
VoiceAssistantAudioSettings,
VoiceAssistantCommandFlag,
VoiceAssistantEventType,
VoiceAssistantFeature,
VoiceAssistantTimerEventType,
)
import pytest
from homeassistant.components import (
assist_pipeline,
assist_satellite,
conversation,
tts,
)
from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
from homeassistant.components.assist_satellite import (
AssistSatelliteConfiguration,
AssistSatelliteEntityFeature,
AssistSatelliteWakeWord,
)
# pylint: disable-next=hass-component-root-import
from homeassistant.components.assist_satellite.entity import AssistSatelliteState
from homeassistant.components.esphome.assist_satellite import VoiceAssistantUDPServer
from homeassistant.components.select import (
DOMAIN as SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
)
from homeassistant.const import ATTR_ENTITY_ID, STATE_UNAVAILABLE
from homeassistant.core import HomeAssistant
from homeassistant.helpers import device_registry as dr, intent as intent_helper
from homeassistant.helpers.network import get_url
from .common import get_satellite_entity
from .conftest import MockESPHomeDevice
from tests.components.tts.common import MockResultStream
@pytest.fixture
def mock_wav() -> bytes:
"""Return test WAV audio."""
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(16000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(b"test-wav")
return wav_io.getvalue()
async def test_no_satellite_without_voice_assistant(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that an assist satellite entity is not created if a voice assistant is not present."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={},
)
await hass.async_block_till_done()
# No satellite entity should be created
assert get_satellite_entity(hass, mock_device.device_info.mac_address) is None
async def test_pipeline_api_audio(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test a complete pipeline run with API audio (over the TCP connection)."""
conversation_id = "test-conversation-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
},
)
await hass.async_block_till_done()
dev = device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
)
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
# Block TTS streaming until we're ready.
# This makes it easier to verify the order of pipeline events.
stream_tts_audio_ready = asyncio.Event()
original_stream_tts_audio = satellite._stream_tts_audio
async def _stream_tts_audio(*args, **kwargs):
await stream_tts_audio_ready.wait()
await original_stream_tts_audio(*args, **kwargs)
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
assert device_id == dev.id
stt_stream = kwargs["stt_stream"]
chunks = [chunk async for chunk in stt_stream]
# Verify test API audio
assert chunks == [b"test-mic"]
event_callback = kwargs["event_callback"]
# Test unknown event type
event_callback(
PipelineEvent(
type="unknown-event",
data={},
)
)
mock_client.send_voice_assistant_event.assert_not_called()
# Test error event
event_callback(
PipelineEvent(
type=PipelineEventType.ERROR,
data={"code": "test-error-code", "message": "test-error-message"},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{"code": "test-error-code", "message": "test-error-message"},
)
# Wake word
assert satellite.state == AssistSatelliteState.IDLE
event_callback(
PipelineEvent(
type=PipelineEventType.WAKE_WORD_START,
data={
"entity_id": "test-wake-word-entity-id",
"metadata": {},
"timeout": 0,
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_START,
{},
)
# Test no wake word detected
event_callback(
PipelineEvent(
type=PipelineEventType.WAKE_WORD_END, data={"wake_word_output": {}}
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_ERROR,
{"code": "no_wake_word", "message": "No wake word detected"},
)
# Correct wake word detection
event_callback(
PipelineEvent(
type=PipelineEventType.WAKE_WORD_END,
data={"wake_word_output": {"wake_word_phrase": "test-wake-word"}},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_WAKE_WORD_END,
{},
)
# STT
event_callback(
PipelineEvent(
type=PipelineEventType.STT_START,
data={"engine": "test-stt-engine", "metadata": {}},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_STT_START,
{},
)
assert satellite.state == AssistSatelliteState.LISTENING
event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": "test-stt-text"}},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_STT_END,
{"text": "test-stt-text"},
)
# Intent
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_START,
data={
"engine": "test-intent-engine",
"language": hass.config.language,
"intent_input": "test-intent-text",
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_START,
{},
)
assert satellite.state == AssistSatelliteState.PROCESSING
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_END,
data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
continue_conversation=True,
).as_dict()
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_INTENT_END,
{
"conversation_id": conversation_id,
"continue_conversation": "1",
},
)
# TTS
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_START,
data={
"engine": "test-stt-engine",
"language": hass.config.language,
"voice": "test-voice",
"tts_input": "test-tts-text",
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_START,
{"text": "test-tts-text"},
)
assert satellite.state == AssistSatelliteState.RESPONDING
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_END,
{"url": get_url(hass) + mock_tts_result_stream.url},
)
event_callback(
PipelineEvent(
type=PipelineEventType.RUN_START,
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_START,
{"url": get_url(hass) + mock_tts_result_stream.url},
)
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_RUN_END,
{},
)
# Allow TTS streaming to proceed
stream_tts_audio_ready.set()
pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished
def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
tts_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "_stream_tts_audio", _stream_tts_audio),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
# Should be cleared at pipeline start
satellite._audio_queue.put_nowait(b"leftover-data")
# Should be cancelled at pipeline start
mock_tts_streaming_task = Mock()
satellite._tts_streaming_task = mock_tts_streaming_task
async with asyncio.timeout(1):
await satellite.handle_pipeline_start(
conversation_id=conversation_id,
flags=VoiceAssistantCommandFlag.USE_WAKE_WORD,
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)
mock_tts_streaming_task.cancel.assert_called_once()
await satellite.handle_audio(b"test-mic")
await satellite.handle_pipeline_stop(abort=False)
await pipeline_finished.wait()
await tts_finished.wait()
# Verify TTS streaming events.
# These are definitely the last two events because we blocked TTS streaming
# until after RUN_END above.
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
{},
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)
# Verify TTS WAV audio chunk came through
mock_client.send_voice_assistant_audio.assert_called_once_with(b"test-wav")
@pytest.mark.usefixtures("socket_enabled")
async def test_pipeline_udp_audio(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test a complete pipeline run with legacy UDP audio.
This test is not as comprehensive as test_pipeline_api_audio since we're
mainly focused on the UDP server.
"""
conversation_id = "test-conversation-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
mic_audio_event = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
stt_stream = kwargs["stt_stream"]
chunks = []
async for chunk in stt_stream:
chunks.append(chunk)
mic_audio_event.set()
# Verify test UDP audio
assert chunks == [b"test-mic"]
event_callback = kwargs["event_callback"]
# STT
event_callback(
PipelineEvent(
type=PipelineEventType.STT_START,
data={"engine": "test-stt-engine", "metadata": {}},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": "test-stt-text"}},
)
)
# Intent
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_START,
data={
"engine": "test-intent-engine",
"language": hass.config.language,
"intent_input": "test-intent-text",
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_END,
data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
).as_dict()
},
)
)
# TTS
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_START,
data={
"engine": "test-stt-engine",
"language": hass.config.language,
"voice": "test-voice",
"tts_input": "test-tts-text",
},
)
)
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished
def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
tts_finished.set()
class TestProtocol(asyncio.DatagramProtocol):
def __init__(self) -> None:
self.transport = None
self.data_received: list[bytes] = []
def connection_made(self, transport):
self.transport = transport
def datagram_received(self, data: bytes, addr):
self.data_received.append(data)
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
async with asyncio.timeout(1):
port = await satellite.handle_pipeline_start(
conversation_id=conversation_id,
flags=VoiceAssistantCommandFlag(0), # stt
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)
assert (port is not None) and (port > 0)
(
transport,
protocol,
) = await asyncio.get_running_loop().create_datagram_endpoint(
TestProtocol, remote_addr=("127.0.0.1", port)
)
assert isinstance(protocol, TestProtocol)
# Send audio over UDP
transport.sendto(b"test-mic")
# Wait for audio chunk to be delivered
await mic_audio_event.wait()
await satellite.handle_pipeline_stop(abort=False)
await pipeline_finished.wait()
await tts_finished.wait()
# Verify TTS audio (from UDP)
assert protocol.data_received == [b"test-wav"]
# Check that UDP server was stopped
sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
sock.setblocking(False)
sock.bind(("", port)) # will fail if UDP server is still running
sock.close()
async def test_udp_errors() -> None:
"""Test UDP protocol error conditions."""
audio_queue: asyncio.Queue[bytes | None] = asyncio.Queue()
protocol = VoiceAssistantUDPServer(audio_queue)
protocol.datagram_received(b"test", ("", 0))
assert audio_queue.qsize() == 1
assert (await audio_queue.get()) == b"test"
# None will stop the pipeline
protocol.error_received(RuntimeError())
assert audio_queue.qsize() == 1
assert (await audio_queue.get()) is None
# No transport
assert protocol.transport is None
protocol.send_audio_bytes(b"test")
# No remote address
protocol.transport = Mock()
protocol.remote_addr = None
protocol.send_audio_bytes(b"test")
protocol.transport.sendto.assert_not_called()
async def test_pipeline_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test a complete pipeline run with the TTS response sent to a media player instead of a speaker.
This test is not as comprehensive as test_pipeline_api_audio since we're
mainly focused on tts_response_finished getting automatically called.
"""
conversation_id = "test-conversation-id"
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.API_AUDIO
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
async def async_pipeline_from_audio_stream(*args, device_id, **kwargs):
stt_stream = kwargs["stt_stream"]
async for _chunk in stt_stream:
break
event_callback = kwargs["event_callback"]
# STT
event_callback(
PipelineEvent(
type=PipelineEventType.STT_START,
data={"engine": "test-stt-engine", "metadata": {}},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.STT_END,
data={"stt_output": {"text": "test-stt-text"}},
)
)
# Intent
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_START,
data={
"engine": "test-intent-engine",
"language": hass.config.language,
"intent_input": "test-intent-text",
"conversation_id": conversation_id,
"device_id": device_id,
},
)
)
event_callback(
PipelineEvent(
type=PipelineEventType.INTENT_END,
data={
"intent_output": conversation.ConversationResult(
response=intent_helper.IntentResponse("en"),
conversation_id=conversation_id,
).as_dict()
},
)
)
# TTS
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_START,
data={
"engine": "test-stt-engine",
"language": hass.config.language,
"voice": "test-voice",
"tts_input": "test-tts-text",
},
)
)
# Should return mock_wav audio
mock_tts_result_stream = MockResultStream(hass, "wav", mock_wav)
event_callback(
PipelineEvent(
type=PipelineEventType.TTS_END,
data={
"tts_output": {
"media_id": "test-media-id",
"url": mock_tts_result_stream.url,
"token": mock_tts_result_stream.token,
}
},
)
)
event_callback(PipelineEvent(type=PipelineEventType.RUN_END))
pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished
def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()
tts_finished = asyncio.Event()
original_tts_response_finished = satellite.tts_response_finished
def tts_response_finished():
original_tts_response_finished()
tts_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
patch.object(satellite, "tts_response_finished", tts_response_finished),
):
async with asyncio.timeout(1):
await satellite.handle_pipeline_start(
conversation_id=conversation_id,
flags=VoiceAssistantCommandFlag(0), # stt
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)
await satellite.handle_pipeline_stop(abort=False)
await pipeline_finished.wait()
assert satellite.state == AssistSatelliteState.RESPONDING
# Will trigger tts_response_finished
await mock_device.mock_voice_assistant_handle_announcement_finished(
VoiceAssistantAnnounceFinished(success=True)
)
await tts_finished.wait()
assert satellite.state == AssistSatelliteState.IDLE
async def test_timer_events(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that injecting timer events results in the correct api client calls."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.TIMERS
},
)
await hass.async_block_till_done()
dev = device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
)
total_seconds = (1 * 60 * 60) + (2 * 60) + 3
await intent_helper.async_handle(
hass,
"test",
intent_helper.INTENT_START_TIMER,
{
"name": {"value": "test timer"},
"hours": {"value": 1},
"minutes": {"value": 2},
"seconds": {"value": 3},
},
device_id=dev.id,
)
mock_client.send_voice_assistant_timer_event.assert_called_with(
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_STARTED,
ANY,
"test timer",
total_seconds,
total_seconds,
True,
)
# Increase timer beyond original time and check total_seconds has increased
mock_client.send_voice_assistant_timer_event.reset_mock()
total_seconds += 5 * 60
await intent_helper.async_handle(
hass,
"test",
intent_helper.INTENT_INCREASE_TIMER,
{
"name": {"value": "test timer"},
"minutes": {"value": 5},
},
device_id=dev.id,
)
mock_client.send_voice_assistant_timer_event.assert_called_with(
VoiceAssistantTimerEventType.VOICE_ASSISTANT_TIMER_UPDATED,
ANY,
"test timer",
total_seconds,
ANY,
True,
)
async def test_unknown_timer_event(
hass: HomeAssistant,
device_registry: dr.DeviceRegistry,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that unknown (new) timer event types do not result in api calls."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.TIMERS
},
)
await hass.async_block_till_done()
assert mock_device.entry.unique_id is not None
dev = device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
)
assert dev is not None
with patch(
"homeassistant.components.esphome.assist_satellite._TIMER_EVENT_TYPES.from_hass",
side_effect=KeyError,
):
await intent_helper.async_handle(
hass,
"test",
intent_helper.INTENT_START_TIMER,
{
"name": {"value": "test timer"},
"hours": {"value": 1},
"minutes": {"value": 2},
"seconds": {"value": 3},
},
device_id=dev.id,
)
mock_client.send_voice_assistant_timer_event.assert_not_called()
async def test_streaming_tts_errors(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
mock_wav: bytes,
) -> None:
"""Test error conditions for _stream_tts_audio function."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
# Should not stream if not running
satellite._is_running = False
await satellite._stream_tts_audio(MockResultStream(hass, "wav", mock_wav))
mock_client.send_voice_assistant_audio.assert_not_called()
satellite._is_running = True
# Should only stream WAV
await satellite._stream_tts_audio(MockResultStream(hass, "mp3", b""))
mock_client.send_voice_assistant_audio.assert_not_called()
# Needs to be the correct sample rate, etc.
with io.BytesIO() as wav_io:
with wave.open(wav_io, "wb") as wav_file:
wav_file.setframerate(48000)
wav_file.setsampwidth(2)
wav_file.setnchannels(1)
wav_file.writeframes(b"test-wav")
mock_tts_result_stream = MockResultStream(hass, "wav", wav_io.getvalue())
await satellite._stream_tts_audio(mock_tts_result_stream)
mock_client.send_voice_assistant_audio.assert_not_called()
# Check that TTS_STREAM_* events still get sent after cancel
media_fetched = asyncio.Event()
mock_tts_result_stream = MockResultStream(hass, "wav", b"")
async def async_stream_result_slowly():
media_fetched.set()
await asyncio.sleep(1)
yield mock_wav
mock_tts_result_stream.async_stream_result = async_stream_result_slowly
mock_client.send_voice_assistant_event.reset_mock()
task = asyncio.create_task(satellite._stream_tts_audio(mock_tts_result_stream))
async with asyncio.timeout(1):
# Wait for media to be fetched
await media_fetched.wait()
# Cancel task
task.cancel()
await task
# No audio should have gone out
mock_client.send_voice_assistant_audio.assert_not_called()
assert len(mock_client.send_voice_assistant_event.call_args_list) == 2
# The TTS_STREAM_* events should have gone out
assert mock_client.send_voice_assistant_event.call_args_list[-2].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_START,
{},
)
assert mock_client.send_voice_assistant_event.call_args_list[-1].args == (
VoiceAssistantEventType.VOICE_ASSISTANT_TTS_STREAM_END,
{},
)
async def test_tts_format_from_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that the text-to-speech format is pulled from the first media player."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.DEFAULT,
sample_bytes=2,
),
# This is the format that should be used for tts
MediaPlayerSupportedFormat(
format="mp3",
sample_rate=22050,
num_channels=1,
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
sample_bytes=2,
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream:
await satellite.handle_pipeline_start(
conversation_id="",
flags=0,
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase=None,
)
mock_pipeline_from_audio_stream.assert_called_once()
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
# Should be ANNOUNCEMENT format from media player
assert kwargs.get("tts_audio_output") == {
tts.ATTR_PREFERRED_FORMAT: "mp3",
tts.ATTR_PREFERRED_SAMPLE_RATE: 22050,
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 1,
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
}
async def test_tts_minimal_format_from_media_player(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test text-to-speech format when media player only specifies the codec."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.DEFAULT,
sample_bytes=2,
),
# This is the format that should be used for tts
MediaPlayerSupportedFormat(
format="mp3",
sample_rate=0, # source rate
num_channels=0, # source channels
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
sample_bytes=0, # source width
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
with patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
) as mock_pipeline_from_audio_stream:
await satellite.handle_pipeline_start(
conversation_id="",
flags=0,
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase=None,
)
mock_pipeline_from_audio_stream.assert_called_once()
kwargs = mock_pipeline_from_audio_stream.call_args_list[0].kwargs
# Should be ANNOUNCEMENT format from media player
assert kwargs.get("tts_audio_output") == {
tts.ATTR_PREFERRED_FORMAT: "mp3",
}
async def test_announce_message(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test announcement with message."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str | None = None,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
assert text == "test-text"
assert not start_conversation
assert not preannounce_media_id
done.set()
with (
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud_tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"announce",
{
ATTR_ENTITY_ID: satellite.entity_id,
"message": "test-text",
"preannounce": False,
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
async def test_announce_media_id(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
device_registry: dr.DeviceRegistry,
) -> None:
"""Test announcement with media id."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
sample_bytes=2,
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
dev = device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
)
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str | None = None,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/proxied.flac"
assert not start_conversation
assert not preannounce_media_id
done.set()
with (
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
patch(
"homeassistant.components.esphome.assist_satellite.async_create_proxy_url",
return_value="https://www.home-assistant.io/proxied.flac",
) as mock_async_create_proxy_url,
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"announce",
{
ATTR_ENTITY_ID: satellite.entity_id,
"media_id": "https://www.home-assistant.io/resolved.mp3",
"preannounce": False,
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
mock_async_create_proxy_url.assert_called_once_with(
hass=hass,
device_id=dev.id,
media_url="https://www.home-assistant.io/resolved.mp3",
media_format="flac",
rate=48000,
channels=2,
width=2,
)
async def test_announce_message_with_preannounce(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test announcement with message and preannounce media id."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str | None = None,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
assert text == "test-text"
assert not start_conversation
assert preannounce_media_id == "test-preannounce"
done.set()
with (
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud_tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"announce",
{
ATTR_ENTITY_ID: satellite.entity_id,
"message": "test-text",
"preannounce_media_id": "test-preannounce",
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
async def test_non_default_supported_features(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that the start conversation and announce are not set by default."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert not (
satellite.supported_features & AssistSatelliteEntityFeature.START_CONVERSATION
)
assert not (satellite.supported_features & AssistSatelliteEntityFeature.ANNOUNCE)
async def test_start_conversation_message(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test start conversation with message."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
| VoiceAssistantFeature.START_CONVERSATION
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
pipeline = assist_pipeline.Pipeline(
conversation_engine="test engine",
conversation_language="en",
language="en",
name="test pipeline",
stt_engine="test stt",
stt_language="en",
tts_engine="test tts",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
assert text == "test-text"
assert start_conversation
assert not preannounce_media_id
done.set()
with (
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud_tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
patch(
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
return_value=pipeline,
),
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"start_conversation",
{
ATTR_ENTITY_ID: satellite.entity_id,
"start_message": "test-text",
"preannounce": False,
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
async def test_start_conversation_media_id(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
device_registry: dr.DeviceRegistry,
) -> None:
"""Test start conversation with media id."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[
MediaPlayerInfo(
object_id="mymedia_player",
key=1,
name="my media_player",
unique_id="my_media_player",
supports_pause=True,
supported_formats=[
MediaPlayerSupportedFormat(
format="flac",
sample_rate=48000,
num_channels=2,
purpose=MediaPlayerFormatPurpose.ANNOUNCEMENT,
sample_bytes=2,
),
],
)
],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
| VoiceAssistantFeature.START_CONVERSATION
},
)
await hass.async_block_till_done()
dev = device_registry.async_get_device(
connections={(dr.CONNECTION_NETWORK_MAC, mock_device.entry.unique_id)}
)
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
pipeline = assist_pipeline.Pipeline(
conversation_engine="test engine",
conversation_language="en",
language="en",
name="test pipeline",
stt_engine="test stt",
stt_language="en",
tts_engine="test tts",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "https://www.home-assistant.io/proxied.flac"
assert start_conversation
assert not preannounce_media_id
done.set()
with (
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
patch(
"homeassistant.components.esphome.assist_satellite.async_create_proxy_url",
return_value="https://www.home-assistant.io/proxied.flac",
) as mock_async_create_proxy_url,
patch(
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
return_value=pipeline,
),
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"start_conversation",
{
ATTR_ENTITY_ID: satellite.entity_id,
"start_media_id": "https://www.home-assistant.io/resolved.mp3",
"preannounce": False,
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
mock_async_create_proxy_url.assert_called_once_with(
hass=hass,
device_id=dev.id,
media_url="https://www.home-assistant.io/resolved.mp3",
media_format="flac",
rate=48000,
channels=2,
width=2,
)
async def test_start_conversation_message_with_preannounce(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test start conversation with message and preannounce media id."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.SPEAKER
| VoiceAssistantFeature.API_AUDIO
| VoiceAssistantFeature.ANNOUNCE
| VoiceAssistantFeature.START_CONVERSATION
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
pipeline = assist_pipeline.Pipeline(
conversation_engine="test engine",
conversation_language="en",
language="en",
name="test pipeline",
stt_engine="test stt",
stt_language="en",
tts_engine="test tts",
tts_language="en",
tts_voice=None,
wake_word_entity=None,
wake_word_id=None,
)
done = asyncio.Event()
async def send_voice_assistant_announcement_await_response(
media_id: str,
timeout: float,
text: str,
start_conversation: bool,
preannounce_media_id: str,
):
assert satellite.state == AssistSatelliteState.RESPONDING
assert media_id == "http://10.10.10.10:8123/api/tts_proxy/test-token"
assert text == "test-text"
assert start_conversation
assert preannounce_media_id == "test-preannounce"
done.set()
with (
patch(
"homeassistant.components.tts.generate_media_source_id",
return_value="media-source://bla",
),
patch(
"homeassistant.components.tts.async_resolve_engine",
return_value="tts.cloud_tts",
),
patch(
"homeassistant.components.tts.async_create_stream",
return_value=MockResultStream(hass, "wav", b""),
),
patch.object(
mock_client,
"send_voice_assistant_announcement_await_response",
new=send_voice_assistant_announcement_await_response,
),
patch(
"homeassistant.components.assist_satellite.entity.async_get_pipeline",
return_value=pipeline,
),
):
async with asyncio.timeout(1):
await hass.services.async_call(
assist_satellite.DOMAIN,
"start_conversation",
{
ATTR_ENTITY_ID: satellite.entity_id,
"start_message": "test-text",
"preannounce_media_id": "test-preannounce",
},
blocking=True,
)
await done.wait()
assert satellite.state == AssistSatelliteState.IDLE
async def test_satellite_unloaded_on_disconnect(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test that the assist satellite platform is unloaded on disconnect."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
state = hass.states.get(satellite.entity_id)
assert state is not None
assert state.state != STATE_UNAVAILABLE
# Device will be unavailable after disconnect
await mock_device.mock_disconnect(True)
state = hass.states.get(satellite.entity_id)
assert state is not None
assert state.state == STATE_UNAVAILABLE
async def test_pipeline_abort(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test aborting a pipeline (no further processing)."""
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.API_AUDIO
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
chunks = []
chunk_received = asyncio.Event()
pipeline_aborted = asyncio.Event()
async def async_pipeline_from_audio_stream(*args, **kwargs):
stt_stream = kwargs["stt_stream"]
try:
async for chunk in stt_stream:
chunks.append(chunk)
chunk_received.set()
except asyncio.CancelledError:
# Aborting cancels the pipeline task
pipeline_aborted.set()
raise
pipeline_finished = asyncio.Event()
original_handle_pipeline_finished = satellite.handle_pipeline_finished
def handle_pipeline_finished():
original_handle_pipeline_finished()
pipeline_finished.set()
with (
patch(
"homeassistant.components.assist_satellite.entity.async_pipeline_from_audio_stream",
new=async_pipeline_from_audio_stream,
),
patch.object(satellite, "handle_pipeline_finished", handle_pipeline_finished),
):
async with asyncio.timeout(1):
await satellite.handle_pipeline_start(
conversation_id="",
flags=VoiceAssistantCommandFlag(0), # stt
audio_settings=VoiceAssistantAudioSettings(),
wake_word_phrase="",
)
await satellite.handle_audio(b"before-abort")
await chunk_received.wait()
# Abort the pipeline, no further processing
await satellite.handle_pipeline_stop(abort=True)
await pipeline_aborted.wait()
# This chunk should not make it into the STT stream
await satellite.handle_audio(b"after-abort")
await pipeline_finished.wait()
# Only first chunk
assert chunks == [b"before-abort"]
async def test_get_set_configuration(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test getting and setting the satellite configuration."""
expected_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
AssistSatelliteWakeWord("5678", "hey jarvis", ["en"]),
],
active_wake_words=["1234"],
max_active_wake_words=1,
)
mock_client.get_voice_assistant_configuration.return_value = expected_config
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
# HA should have been updated
actual_config = satellite.async_get_configuration()
assert actual_config == expected_config
updated_config = replace(actual_config, active_wake_words=["5678"])
mock_client.get_voice_assistant_configuration.return_value = updated_config
# Change active wake words
await satellite.async_set_configuration(updated_config)
# Set config method should be called
mock_client.set_voice_assistant_configuration.assert_called_once_with(
active_wake_words=["5678"]
)
# Device should have been updated
assert satellite.async_get_configuration() == updated_config
async def test_wake_word_select(
hass: HomeAssistant,
mock_client: APIClient,
mock_esphome_device: Callable[
[APIClient, list[EntityInfo], list[UserService], list[EntityState]],
Awaitable[MockESPHomeDevice],
],
) -> None:
"""Test wake word select."""
device_config = AssistSatelliteConfiguration(
available_wake_words=[
AssistSatelliteWakeWord("okay_nabu", "Okay Nabu", ["en"]),
AssistSatelliteWakeWord("hey_jarvis", "Hey Jarvis", ["en"]),
AssistSatelliteWakeWord("hey_mycroft", "Hey Mycroft", ["en"]),
],
active_wake_words=["hey_jarvis"],
max_active_wake_words=1,
)
mock_client.get_voice_assistant_configuration.return_value = device_config
# Wrap mock so we can tell when it's done
configuration_set = asyncio.Event()
async def wrapper(*args, **kwargs):
# Update device config because entity will request it after update
device_config.active_wake_words = kwargs["active_wake_words"]
configuration_set.set()
mock_client.set_voice_assistant_configuration = AsyncMock(side_effect=wrapper)
mock_device: MockESPHomeDevice = await mock_esphome_device(
mock_client=mock_client,
entity_info=[],
user_service=[],
states=[],
device_info={
"voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
| VoiceAssistantFeature.ANNOUNCE
},
)
await hass.async_block_till_done()
satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
assert satellite is not None
assert satellite.async_get_configuration().active_wake_words == ["hey_jarvis"]
# Active wake word should be selected
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == "Hey Jarvis"
# Changing the select should set the active wake word
await hass.services.async_call(
SELECT_DOMAIN,
SERVICE_SELECT_OPTION,
{ATTR_ENTITY_ID: "select.test_wake_word", "option": "Okay Nabu"},
blocking=True,
)
await hass.async_block_till_done()
state = hass.states.get("select.test_wake_word")
assert state is not None
assert state.state == "Okay Nabu"
# Wait for device config to be updated
async with asyncio.timeout(1):
await configuration_set.wait()
# Satellite config should have been updated
assert satellite.async_get_configuration().active_wake_words == ["okay_nabu"]