Move Assist Pipeline tests to right file (#144696)

This commit is contained in:
Paulus Schoutsen 2025-05-11 15:38:21 -04:00 committed by GitHub
parent b394c07a3d
commit 4faa920318
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 1079 additions and 1051 deletions

View File

@@ -1,5 +1,10 @@
"""Tests for the Voice Assistant integration."""
from dataclasses import asdict
from unittest.mock import ANY
from homeassistant.components import assist_pipeline
MANY_LANGUAGES = [
"ar",
"bg",
@@ -54,3 +59,16 @@ MANY_LANGUAGES = [
"zh-hk",
"zh-tw",
]
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
    """Process events to remove dynamic values.

    Strips the per-event timestamp and replaces the runtime-generated
    pipeline id in RUN_START events with ANY so that event lists compare
    deterministically against stored snapshots.
    """
    normalized: list[dict] = []
    for evt in events:
        payload = asdict(evt)
        # Timestamps differ on every run; drop them before comparison.
        del payload["timestamp"]
        if payload["type"] == assist_pipeline.PipelineEventType.RUN_START:
            # Pipeline ids are generated at runtime; match anything.
            payload["data"]["pipeline"] = ANY
        normalized.append(payload)
    return normalized

View File

@@ -461,204 +461,3 @@
}),
])
# ---
# name: test_pipeline_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_stt_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-US',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_tts_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-us',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_wake_word_detection_aborted
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
dict({
'data': dict({
'code': 'wake_word_detection_aborted',
'message': '',
}),
'type': <PipelineEventType.ERROR: 'error'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---

View File

@@ -0,0 +1,202 @@
# serializer version: 1
# name: test_pipeline_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_stt_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-US',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_tts_language_used_instead_of_conversation_language
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'device_id': None,
'engine': 'conversation.home_assistant',
'intent_input': 'test input',
'language': 'en-us',
'prefer_local_intents': False,
}),
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
}),
dict({
'data': dict({
'intent_output': dict({
'continue_conversation': False,
'conversation_id': <ANY>,
'response': dict({
'card': dict({
}),
'data': dict({
'failed': list([
]),
'success': list([
]),
'targets': list([
]),
}),
'language': 'en',
'response_type': 'action_done',
'speech': dict({
}),
}),
}),
'processed_locally': True,
}),
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---
# name: test_wake_word_detection_aborted
list([
dict({
'data': dict({
'conversation_id': 'mock-ulid',
'language': 'en',
'pipeline': <ANY>,
'tts_output': dict({
'mime_type': 'audio/mpeg',
'token': 'mocked-token.mp3',
'url': '/api/tts_proxy/mocked-token.mp3',
}),
}),
'type': <PipelineEventType.RUN_START: 'run-start'>,
}),
dict({
'data': dict({
'entity_id': 'wake_word.test',
'metadata': dict({
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
'channel': <AudioChannels.CHANNEL_MONO: 1>,
'codec': <AudioCodecs.PCM: 'pcm'>,
'format': <AudioFormats.WAV: 'wav'>,
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
}),
'timeout': 0,
}),
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
}),
dict({
'data': dict({
'code': 'wake_word_detection_aborted',
'message': '',
}),
'type': <PipelineEventType.ERROR: 'error'>,
}),
dict({
'data': None,
'type': <PipelineEventType.RUN_END: 'run-end'>,
}),
])
# ---

View File

@@ -2,44 +2,35 @@
import asyncio
from collections.abc import Generator
from dataclasses import asdict
import itertools as it
from pathlib import Path
import tempfile
from unittest.mock import ANY, Mock, patch
from unittest.mock import Mock, patch
import wave
import hass_nabucasa
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import (
assist_pipeline,
conversation,
media_source,
stt,
tts,
)
from homeassistant.components import assist_pipeline, stt
from homeassistant.components.assist_pipeline.const import (
BYTES_PER_CHUNK,
CONF_DEBUG_RECORDING_DIR,
DOMAIN,
)
from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent
from homeassistant.setup import async_setup_component
from . import process_events
from .conftest import (
BYTES_ONE_SECOND,
MockSTTProvider,
MockSTTProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
make_10ms_chunk,
)
from tests.typing import ClientSessionGenerator, WebSocketGenerator
from tests.typing import WebSocketGenerator
@pytest.fixture(autouse=True)
@@ -58,19 +49,6 @@ def mock_tts_token() -> Generator[None]:
yield
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
    """Process events to remove dynamic values."""
    processed = []
    for event in events:
        as_dict = asdict(event)
        # Timestamps differ per run; drop them so snapshot comparison is stable.
        as_dict.pop("timestamp")
        if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
            # The pipeline id is generated at runtime; match anything.
            as_dict["data"]["pipeline"] = ANY
        processed.append(as_dict)
    return processed
async def test_pipeline_from_audio_stream_auto(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,
@@ -677,823 +655,6 @@ async def test_pipeline_saved_audio_empty_queue(
)
async def test_wake_word_detection_aborted(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test creating a pipeline from an audio stream with wake word."""
    events: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        # Fake microphone stream: silence, then the wake word, then speech.
        yield make_10ms_chunk(b"silence!")
        yield make_10ms_chunk(b"wake word!")
        yield make_10ms_chunk(b"part1")
        yield make_10ms_chunk(b"part2")
        yield b""

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        session=mock_chat_session,
        device_id=None,
        stt_metadata=stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        stt_stream=audio_data(),
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output=None,
            wake_word_settings=assist_pipeline.WakeWordSettings(
                audio_seconds_to_buffer=1.5
            ),
            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
        ),
    )
    await pipeline_input.validate()

    # Update the stored pipeline between validate() and execute(); the run
    # should then abort during wake word detection (see snapshot for events).
    updates = pipeline.to_json()
    updates.pop("id")
    await pipeline_store.async_update_item(
        pipeline_id,
        updates,
    )
    await pipeline_input.execute()

    assert process_events(events) == snapshot
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
    """Test that pipeline run equality uses unique id."""

    def event_callback(event):
        pass

    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)

    def make_run():
        # Build a run from identical arguments each time; equality must
        # still distinguish the two instances via their unique ids.
        return assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.STT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=event_callback,
        )

    first_run = make_run()
    second_run = make_run()

    assert first_run == first_run  # noqa: PLR0124
    assert first_run != second_run
    assert first_run != 1234
async def test_tts_audio_output(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test using tts_audio_output with wav sets options correctly."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # TTS-only run: start and end stage are both TTS.
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Verify TTS audio settings
    assert pipeline_input.run.tts_stream.options is not None
    assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert (
        pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
        == 16000
    )
    assert (
        pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
        == 1
    )

    with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    # Ensure that no unsupported options were passed in
    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
    extra_options = set(options).difference(mock_tts_provider.supported_options)
    assert len(extra_options) == 0, extra_options
async def test_tts_wav_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # TTS-only run with the "wav" shorthand for tts_audio_output.
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    supported_options = list(mock_tts_provider.supported_options or [])
    supported_options.extend(
        [
            tts.ATTR_PREFERRED_FORMAT,
            tts.ATTR_PREFERRED_SAMPLE_RATE,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
            tts.ATTR_PREFERRED_SAMPLE_BYTES,
        ]
    )

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # We should have received preferred format options in get_tts_audio
    assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    # Unlike the "wav" shorthand, tts_audio_output here is a full dict of
    # preferred-format options that must be forwarded to the provider.
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=events.append,
            tts_audio_output={
                tts.ATTR_PREFERRED_FORMAT: "flac",
                tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
                tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
                tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
            },
        ),
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    supported_options = list(mock_tts_provider.supported_options or [])
    supported_options.extend(
        [
            tts.ATTR_PREFERRED_FORMAT,
            tts.ATTR_PREFERRED_SAMPLE_RATE,
            tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
            tts.ATTR_PREFERRED_SAMPLE_BYTES,
        ]
    )

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in events:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # We should have received preferred format options in get_tts_audio
    assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_sentence_trigger_overrides_conversation_agent(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that sentence triggers are checked before a non-default conversation agent."""
    # Register an automation whose conversation trigger should win over the agent.
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": [
                        "test trigger sentence",
                    ],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    events: list[assist_pipeline.PipelineEvent] = []

    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test trigger sentence",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
            intent_agent="test-agent",  # not the default agent
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Sentence trigger should have been handled
    mock_async_converse.assert_not_called()

    # Verify sentence trigger response
    intent_end_event = next(
        (
            e
            for e in events
            if e.type == assist_pipeline.PipelineEventType.INTENT_END
        ),
        None,
    )
    assert (intent_end_event is not None) and intent_end_event.data
    assert (
        intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
            "speech"
        ]
        == "test trigger response"
    )
async def test_prefer_local_intents(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that the default agent is checked first when local intents are preferred."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Reuse custom sentences in test config
    class OrderBeerIntentHandler(intent.IntentHandler):
        # Local intent handler that should be matched before the remote agent.
        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            response = intent_obj.create_response()
            response.async_set_speech("Order confirmed")
            return response

    handler = OrderBeerIntentHandler()
    intent.async_register(hass, handler)

    # Fake a test agent and prefer local intents
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
    )
    # Re-fetch to pick up the updated pipeline settings.
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="I'd like to order a stout please",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Test agent should not have been called
    mock_async_converse.assert_not_called()

    # Verify local intent response
    intent_end_event = next(
        (
            e
            for e in events
            if e.type == assist_pipeline.PipelineEventType.INTENT_END
        ),
        None,
    )
    assert (intent_end_event is not None) and intent_end_event.data
    assert (
        intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
            "speech"
        ]
        == "Order confirmed"
    )
async def test_intent_continue_conversation(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a conversation agent flagging continue conversation gets response."""
    events: list[assist_pipeline.PipelineEvent] = []

    # Fake a test agent and prefer local intents
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    # Re-fetch to pick up the updated conversation engine.
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Set a timer",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    # First turn: agent asks a follow-up question and flags continue_conversation.
    response = intent.IntentResponse("en")
    response.async_set_speech("For how long?")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
            continue_conversation=True,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    results = [
        event.data
        for event in events
        if event.type
        in (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
    ]
    assert results[1]["intent_output"]["continue_conversation"] is True

    # Change conversation agent to default one and register sentence trigger that should not be called
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine=None
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": ["Hello"],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    # Because we did continue conversation, it should respond to the test agent again.
    events.clear()

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="Hello",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ) as mock_prepare:
        await pipeline_input.validate()

    # It requested test agent even if that was not default agent.
    assert mock_prepare.mock_calls[0][1][1] == "test-agent"

    # Second turn: agent answers without continuing the conversation.
    response = intent.IntentResponse("en")
    response.async_set_speech("Timer set for 20 minutes")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    # Snapshot will show it was still handled by the test agent and not default agent
    results = [
        event.data
        for event in events
        if event.type
        in (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
    ]
    assert results[0]["engine"] == "test-agent"
    assert results[1]["intent_output"]["continue_conversation"] is False
async def test_stt_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the STT language is used first when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline whose conversation language matches all languages
    # but whose STT language is a concrete value (en-US).
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": "test",
            "stt_language": "en-US",
            "tts_engine": "test",
            "tts_language": "en-US",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # STT language (en-US) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.stt_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.stt_language
    )
async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)

    events: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline with no STT language so the TTS language (en-us)
    # is the next fallback after the '*' conversation language.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": "en-us",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start: assist_pipeline.PipelineEvent | None = None
    for event in events:
        if event.type == assist_pipeline.PipelineEventType.INTENT_START:
            intent_start = event
            break

    assert intent_start is not None

    # TTS language (en-us) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.tts_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.tts_language
    )
async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)
    events: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline with no STT/TTS language so only the pipeline
    # language can win over the '*' conversation language.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start = next(
        (
            event
            for event in events
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # Pipeline language ("en") should be used instead of '*' because
    # neither the STT nor the TTS language is set.
    assert intent_start.data.get("language") == pipeline.language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.language
    )
async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
hass: HomeAssistant,
mock_stt_provider_entity: MockSTTProviderEntity,

View File

@ -1,13 +1,20 @@
"""Websocket tests for Voice Assistant integration."""
from collections.abc import AsyncGenerator
from collections.abc import AsyncGenerator, Generator
from typing import Any
from unittest.mock import ANY, patch
from unittest.mock import ANY, Mock, patch
from hassil.recognize import Intent, IntentData, RecognizeResult
import pytest
from syrupy.assertion import SnapshotAssertion
from homeassistant.components import conversation
from homeassistant.components import (
assist_pipeline,
conversation,
media_source,
stt,
tts,
)
from homeassistant.components.assist_pipeline.const import DOMAIN
from homeassistant.components.assist_pipeline.pipeline import (
STORAGE_KEY,
@ -24,14 +31,22 @@ from homeassistant.components.assist_pipeline.pipeline import (
async_migrate_engine,
async_update_pipeline,
)
from homeassistant.core import HomeAssistant
from homeassistant.helpers import intent
from homeassistant.const import MATCH_ALL
from homeassistant.core import Context, HomeAssistant
from homeassistant.helpers import chat_session, intent
from homeassistant.setup import async_setup_component
from . import MANY_LANGUAGES
from .conftest import MockSTTProviderEntity, MockTTSProvider
from . import MANY_LANGUAGES, process_events
from .conftest import (
MockSTTProvider,
MockSTTProviderEntity,
MockTTSProvider,
MockWakeWordEntity,
make_10ms_chunk,
)
from tests.common import flush_store
from tests.typing import ClientSessionGenerator, WebSocketGenerator
@pytest.fixture(autouse=True)
@ -119,6 +134,22 @@ async def test_load_pipelines(hass: HomeAssistant) -> None:
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
@pytest.fixture(autouse=True)
def mock_chat_session_id() -> Generator[Mock]:
"""Mock the conversation ID of chat sessions."""
with patch(
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
) as mock_ulid_now:
yield mock_ulid_now
@pytest.fixture(autouse=True)
def mock_tts_token() -> Generator[None]:
"""Mock the TTS token for URLs."""
with patch("secrets.token_urlsafe", return_value="mocked-token"):
yield
async def test_loading_pipelines_from_storage(
hass: HomeAssistant, hass_storage: dict[str, Any]
) -> None:
@ -697,3 +728,820 @@ def test_fallback_intent_filter() -> None:
)
is False
)
async def test_wake_word_detection_aborted(
    hass: HomeAssistant,
    mock_stt_provider: MockSTTProvider,
    mock_wake_word_provider_entity: MockWakeWordEntity,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test wake word stream is first detected, then aborted."""
    captured: list[assist_pipeline.PipelineEvent] = []

    async def audio_data():
        # A few audio chunks followed by an empty chunk to end the stream.
        for payload in (b"silence!", b"wake word!", b"part1", b"part2"):
            yield make_10ms_chunk(payload)
        yield b""

    store = pipeline_data.pipeline_store
    preferred_id = store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        session=mock_chat_session,
        device_id=None,
        stt_metadata=stt.SpeechMetadata(
            language="",
            format=stt.AudioFormats.WAV,
            codec=stt.AudioCodecs.PCM,
            bit_rate=stt.AudioBitRates.BITRATE_16,
            sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
            channel=stt.AudioChannels.CHANNEL_MONO,
        ),
        stt_stream=audio_data(),
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=captured.append,
            tts_audio_output=None,
            wake_word_settings=assist_pipeline.WakeWordSettings(
                audio_seconds_to_buffer=1.5
            ),
            audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
        ),
    )
    await pipeline_input.validate()

    # Update the stored pipeline between validate and execute — per the test
    # name this presumably is what aborts the run; snapshot confirms events.
    updates = pipeline.to_json()
    updates.pop("id")
    await store.async_update_item(
        preferred_id,
        updates,
    )

    await pipeline_input.execute()
    assert process_events(captured) == snapshot
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
    """Test that pipeline run equality uses unique id."""

    def noop_callback(event):
        # Events are irrelevant here; only run identity is under test.
        pass

    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
    runs = [
        assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.STT,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=noop_callback,
        )
        for _ in range(2)
    ]

    # A run equals itself, but two identically-configured runs differ.
    assert runs[0] == runs[0]  # noqa: PLR0124
    assert runs[0] != runs[1]
    assert runs[0] != 1234
async def test_tts_audio_output(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test using tts_audio_output with wav sets options correctly."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    captured: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=captured.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Verify TTS audio settings derived from the "wav" output request.
    stream_options = pipeline_input.run.tts_stream.options
    assert stream_options is not None
    assert stream_options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert stream_options.get(tts.ATTR_PREFERRED_SAMPLE_RATE) == 16000
    assert stream_options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS) == 1

    with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
        await pipeline_input.execute()

        for event in captured:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    # Ensure that no unsupported options were passed in
    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
    extra_options = set(options).difference(mock_tts_provider.supported_options)
    assert len(extra_options) == 0, extra_options
async def test_tts_wav_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    captured: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=captured.append,
            tts_audio_output="wav",
        ),
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    supported_options = [
        *(mock_tts_provider.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in captured:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # We should have received preferred format options in get_tts_audio
    assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_tts_dict_preferred_format(
    hass: HomeAssistant,
    hass_client: ClientSessionGenerator,
    mock_tts_provider: MockTTSProvider,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that preferred format options are given to the TTS system if supported."""
    client = await hass_client()
    assert await async_setup_component(hass, media_source.DOMAIN, {})

    captured: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        tts_input="This is a test.",
        session=mock_chat_session,
        device_id=None,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.TTS,
            end_stage=assist_pipeline.PipelineStage.TTS,
            event_callback=captured.append,
            # Request a specific output format via a dict of preferred options.
            tts_audio_output={
                tts.ATTR_PREFERRED_FORMAT: "flac",
                tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
                tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
                tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
            },
        ),
    )
    await pipeline_input.validate()

    # Make the TTS provider support preferred format options
    supported_options = [
        *(mock_tts_provider.supported_options or []),
        tts.ATTR_PREFERRED_FORMAT,
        tts.ATTR_PREFERRED_SAMPLE_RATE,
        tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
        tts.ATTR_PREFERRED_SAMPLE_BYTES,
    ]

    with (
        patch.object(mock_tts_provider, "_supported_options", supported_options),
        patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
    ):
        await pipeline_input.execute()

        for event in captured:
            if event.type == assist_pipeline.PipelineEventType.TTS_END:
                # We must fetch the media URL to trigger the TTS
                assert event.data
                await client.get(event.data["tts_output"]["url"])

    assert mock_get_tts_audio.called
    options = mock_get_tts_audio.call_args_list[0].kwargs["options"]

    # We should have received preferred format options in get_tts_audio
    assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
    assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
async def test_sentence_trigger_overrides_conversation_agent(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that sentence triggers are checked before a non-default conversation agent."""
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": [
                        "test trigger sentence",
                    ],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    captured: list[assist_pipeline.PipelineEvent] = []

    preferred_id = pipeline_data.pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, preferred_id)
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test trigger sentence",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=captured.append,
            intent_agent="test-agent",  # not the default agent
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # Sentence trigger should have been handled
        mock_async_converse.assert_not_called()

        # Verify sentence trigger response
        intent_end_event = None
        for event in captured:
            if event.type == assist_pipeline.PipelineEventType.INTENT_END:
                intent_end_event = event
                break

        assert (intent_end_event is not None) and intent_end_event.data
        speech = intent_end_event.data["intent_output"]["response"]["speech"]
        assert speech["plain"]["speech"] == "test trigger response"
async def test_prefer_local_intents(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that the default agent is checked first when local intents are preferred."""
    captured: list[assist_pipeline.PipelineEvent] = []

    # Reuse custom sentences in test config
    class OrderBeerIntentHandler(intent.IntentHandler):
        """Local handler for the OrderBeer intent."""

        intent_type = "OrderBeer"

        async def async_handle(
            self, intent_obj: intent.Intent
        ) -> intent.IntentResponse:
            """Respond with a fixed confirmation speech."""
            response = intent_obj.create_response()
            response.async_set_speech("Order confirmed")
            return response

    intent.async_register(hass, OrderBeerIntentHandler())

    # Fake a test agent and prefer local intents
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
    )
    # Re-fetch so we use the updated pipeline settings.
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="I'd like to order a stout please",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=captured.append,
        ),
    )

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
    ) as mock_async_converse:
        await pipeline_input.execute()

        # Test agent should not have been called
        mock_async_converse.assert_not_called()

        # Verify local intent response
        intent_end_event = None
        for event in captured:
            if event.type == assist_pipeline.PipelineEventType.INTENT_END:
                intent_end_event = event
                break

        assert (intent_end_event is not None) and intent_end_event.data
        speech = intent_end_event.data["intent_output"]["response"]["speech"]
        assert speech["plain"]["speech"] == "Order confirmed"
async def test_intent_continue_conversation(
    hass: HomeAssistant,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    pipeline_data: assist_pipeline.pipeline.PipelineData,
) -> None:
    """Test that a conversation agent flagging continue conversation gets response."""
    captured: list[assist_pipeline.PipelineEvent] = []

    def intent_event_data() -> list:
        """Return data of captured INTENT_START/INTENT_END events, in order."""
        wanted = (
            assist_pipeline.PipelineEventType.INTENT_START,
            assist_pipeline.PipelineEventType.INTENT_END,
        )
        return [event.data for event in captured if event.type in wanted]

    def build_input(
        pipeline: assist_pipeline.Pipeline, text: str
    ) -> assist_pipeline.pipeline.PipelineInput:
        """Build an intent-only pipeline input for the given text."""
        return assist_pipeline.pipeline.PipelineInput(
            intent_input=text,
            session=mock_chat_session,
            run=assist_pipeline.pipeline.PipelineRun(
                hass,
                context=Context(),
                pipeline=pipeline,
                start_stage=assist_pipeline.PipelineStage.INTENT,
                end_stage=assist_pipeline.PipelineStage.INTENT,
                event_callback=captured.append,
            ),
        )

    # Fake a test agent.
    pipeline_store = pipeline_data.pipeline_store
    pipeline_id = pipeline_store.async_get_preferred_item()
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine="test-agent"
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    pipeline_input = build_input(pipeline, "Set a timer")

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ):
        await pipeline_input.validate()

    response = intent.IntentResponse("en")
    response.async_set_speech("For how long?")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
            continue_conversation=True,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()
    assert intent_event_data()[1]["intent_output"]["continue_conversation"] is True

    # Change conversation agent to default one and register sentence trigger that should not be called
    await assist_pipeline.pipeline.async_update_pipeline(
        hass, pipeline, conversation_engine=None
    )
    pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
    assert await async_setup_component(
        hass,
        "automation",
        {
            "automation": {
                "trigger": {
                    "platform": "conversation",
                    "command": ["Hello"],
                },
                "action": {
                    "set_conversation_response": "test trigger response",
                },
            }
        },
    )

    # Because we did continue conversation, it should respond to the test agent again.
    captured.clear()
    pipeline_input = build_input(pipeline, "Hello")

    # Ensure prepare succeeds
    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
        return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
    ) as mock_prepare:
        await pipeline_input.validate()

    # It requested test agent even if that was not default agent.
    assert mock_prepare.mock_calls[0][1][1] == "test-agent"

    response = intent.IntentResponse("en")
    response.async_set_speech("Timer set for 20 minutes")

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            response=response,
            conversation_id=mock_chat_session.conversation_id,
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    mock_async_converse.assert_called()

    # Snapshot will show it was still handled by the test agent and not default agent
    results = intent_event_data()
    assert results[0]["engine"] == "test-agent"
    assert results[1]["intent_output"]["continue_conversation"] is False
async def test_stt_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the STT language is used first when the conversation language is '*' (all languages)."""
    ws_client = await hass_ws_client(hass)
    captured: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline whose conversation language matches everything but
    # which has an explicit STT language.
    await ws_client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": "test",
            "stt_language": "en-US",
            "tts_engine": "test",
            "tts_language": "en-US",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    reply = await ws_client.receive_json()
    assert reply["success"]

    pipeline = assist_pipeline.async_get_pipeline(hass, reply["result"]["id"])
    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=captured.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(captured) == snapshot
    intent_start = next(
        (
            event
            for event in captured
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # STT language (en-US) should be used instead of '*'
    assert intent_start.data.get("language") == pipeline.stt_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.stt_language
    )
async def test_tts_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)
    events: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline with no STT language so the TTS language is the
    # first concrete language available.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": "en-us",
            "tts_voice": "Arnold Schwarzenegger",
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start = next(
        (
            event
            for event in events
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # TTS language (en-us) should be used instead of '*' because the STT
    # language is not set.
    assert intent_start.data.get("language") == pipeline.tts_language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.tts_language
    )
async def test_pipeline_language_used_instead_of_conversation_language(
    hass: HomeAssistant,
    hass_ws_client: WebSocketGenerator,
    init_components,
    mock_chat_session: chat_session.ChatSession,
    snapshot: SnapshotAssertion,
) -> None:
    """Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
    client = await hass_ws_client(hass)
    events: list[assist_pipeline.PipelineEvent] = []

    # Create a pipeline with no STT/TTS language so only the pipeline
    # language can win over the '*' conversation language.
    await client.send_json_auto_id(
        {
            "type": "assist_pipeline/pipeline/create",
            "conversation_engine": "homeassistant",
            "conversation_language": MATCH_ALL,
            "language": "en",
            "name": "test_name",
            "stt_engine": None,
            "stt_language": None,
            "tts_engine": None,
            "tts_language": None,
            "tts_voice": None,
            "wake_word_entity": None,
            "wake_word_id": None,
        }
    )
    msg = await client.receive_json()
    assert msg["success"]
    pipeline_id = msg["result"]["id"]
    pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)

    pipeline_input = assist_pipeline.pipeline.PipelineInput(
        intent_input="test input",
        session=mock_chat_session,
        run=assist_pipeline.pipeline.PipelineRun(
            hass,
            context=Context(),
            pipeline=pipeline,
            start_stage=assist_pipeline.PipelineStage.INTENT,
            end_stage=assist_pipeline.PipelineStage.INTENT,
            event_callback=events.append,
        ),
    )
    await pipeline_input.validate()

    with patch(
        "homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
        return_value=conversation.ConversationResult(
            intent.IntentResponse(pipeline.language)
        ),
    ) as mock_async_converse:
        await pipeline_input.execute()

    # Check intent start event
    assert process_events(events) == snapshot
    intent_start = next(
        (
            event
            for event in events
            if event.type == assist_pipeline.PipelineEventType.INTENT_START
        ),
        None,
    )
    assert intent_start is not None

    # Pipeline language ("en") should be used instead of '*' because
    # neither the STT nor the TTS language is set.
    assert intent_start.data.get("language") == pipeline.language

    # Check input to async_converse
    mock_async_converse.assert_called_once()
    assert (
        mock_async_converse.call_args_list[0].kwargs.get("language")
        == pipeline.language
    )