mirror of
https://github.com/home-assistant/core.git
synced 2025-07-28 15:47:12 +00:00
Move Assist Pipeline tests to right file (#144696)
This commit is contained in:
parent
b394c07a3d
commit
4faa920318
@ -1,5 +1,10 @@
|
||||
"""Tests for the Voice Assistant integration."""
|
||||
|
||||
from dataclasses import asdict
|
||||
from unittest.mock import ANY
|
||||
|
||||
from homeassistant.components import assist_pipeline
|
||||
|
||||
MANY_LANGUAGES = [
|
||||
"ar",
|
||||
"bg",
|
||||
@ -54,3 +59,16 @@ MANY_LANGUAGES = [
|
||||
"zh-hk",
|
||||
"zh-tw",
|
||||
]
|
||||
|
||||
|
||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||
"""Process events to remove dynamic values."""
|
||||
processed = []
|
||||
for event in events:
|
||||
as_dict = asdict(event)
|
||||
as_dict.pop("timestamp")
|
||||
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
|
||||
as_dict["data"]["pipeline"] = ANY
|
||||
processed.append(as_dict)
|
||||
|
||||
return processed
|
||||
|
@ -461,204 +461,3 @@
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_stt_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en-US',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_tts_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en-us',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_wake_word_detection_aborted
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'entity_id': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'code': 'wake_word_detection_aborted',
|
||||
'message': '',
|
||||
}),
|
||||
'type': <PipelineEventType.ERROR: 'error'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
|
202
tests/components/assist_pipeline/snapshots/test_pipeline.ambr
Normal file
202
tests/components/assist_pipeline/snapshots/test_pipeline.ambr
Normal file
@ -0,0 +1,202 @@
|
||||
# serializer version: 1
|
||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_stt_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en-US',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_tts_language_used_instead_of_conversation_language
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
'language': 'en-us',
|
||||
'prefer_local_intents': False,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'intent_output': dict({
|
||||
'continue_conversation': False,
|
||||
'conversation_id': <ANY>,
|
||||
'response': dict({
|
||||
'card': dict({
|
||||
}),
|
||||
'data': dict({
|
||||
'failed': list([
|
||||
]),
|
||||
'success': list([
|
||||
]),
|
||||
'targets': list([
|
||||
]),
|
||||
}),
|
||||
'language': 'en',
|
||||
'response_type': 'action_done',
|
||||
'speech': dict({
|
||||
}),
|
||||
}),
|
||||
}),
|
||||
'processed_locally': True,
|
||||
}),
|
||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
||||
# name: test_wake_word_detection_aborted
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'tts_output': dict({
|
||||
'mime_type': 'audio/mpeg',
|
||||
'token': 'mocked-token.mp3',
|
||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||
}),
|
||||
}),
|
||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'entity_id': 'wake_word.test',
|
||||
'metadata': dict({
|
||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||
'format': <AudioFormats.WAV: 'wav'>,
|
||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||
}),
|
||||
'timeout': 0,
|
||||
}),
|
||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'code': 'wake_word_detection_aborted',
|
||||
'message': '',
|
||||
}),
|
||||
'type': <PipelineEventType.ERROR: 'error'>,
|
||||
}),
|
||||
dict({
|
||||
'data': None,
|
||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||
}),
|
||||
])
|
||||
# ---
|
@ -2,44 +2,35 @@
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
from unittest.mock import Mock, patch
|
||||
import wave
|
||||
|
||||
import hass_nabucasa
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import (
|
||||
assist_pipeline,
|
||||
conversation,
|
||||
media_source,
|
||||
stt,
|
||||
tts,
|
||||
)
|
||||
from homeassistant.components import assist_pipeline, stt
|
||||
from homeassistant.components.assist_pipeline.const import (
|
||||
BYTES_PER_CHUNK,
|
||||
CONF_DEBUG_RECORDING_DIR,
|
||||
DOMAIN,
|
||||
)
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import process_events
|
||||
from .conftest import (
|
||||
BYTES_ONE_SECOND,
|
||||
MockSTTProvider,
|
||||
MockSTTProviderEntity,
|
||||
MockTTSProvider,
|
||||
MockWakeWordEntity,
|
||||
make_10ms_chunk,
|
||||
)
|
||||
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -58,19 +49,6 @@ def mock_tts_token() -> Generator[None]:
|
||||
yield
|
||||
|
||||
|
||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||
"""Process events to remove dynamic values."""
|
||||
processed = []
|
||||
for event in events:
|
||||
as_dict = asdict(event)
|
||||
as_dict.pop("timestamp")
|
||||
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
|
||||
as_dict["data"]["pipeline"] = ANY
|
||||
processed.append(as_dict)
|
||||
|
||||
return processed
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_auto(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||
@ -677,823 +655,6 @@ async def test_pipeline_saved_audio_empty_queue(
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_detection_aborted(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSTTProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test creating a pipeline from an audio stream with wake word."""
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
async def audio_data():
|
||||
yield make_10ms_chunk(b"silence!")
|
||||
yield make_10ms_chunk(b"wake word!")
|
||||
yield make_10ms_chunk(b"part1")
|
||||
yield make_10ms_chunk(b"part2")
|
||||
yield b""
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output=None,
|
||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||
audio_seconds_to_buffer=1.5
|
||||
),
|
||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
updates = pipeline.to_json()
|
||||
updates.pop("id")
|
||||
await pipeline_store.async_update_item(
|
||||
pipeline_id,
|
||||
updates,
|
||||
)
|
||||
await pipeline_input.execute()
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
|
||||
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test that pipeline run equality uses unique id."""
|
||||
|
||||
def event_callback(event):
|
||||
pass
|
||||
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
|
||||
run_1 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
run_2 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
|
||||
assert run_1 == run_1 # noqa: PLR0124
|
||||
assert run_1 != run_2
|
||||
assert run_1 != 1234
|
||||
|
||||
|
||||
async def test_tts_audio_output(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test using tts_audio_output with wav sets options correctly."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Verify TTS audio settings
|
||||
assert pipeline_input.run.tts_stream.options is not None
|
||||
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
|
||||
== 16000
|
||||
)
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
== 1
|
||||
)
|
||||
|
||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
# Ensure that no unsupported options were passed in
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
||||
assert len(extra_options) == 0, extra_options
|
||||
|
||||
|
||||
async def test_tts_wav_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS provider support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
||||
|
||||
async def test_tts_dict_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output={
|
||||
tts.ATTR_PREFERRED_FORMAT: "flac",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
},
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS provider support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
||||
|
||||
async def test_sentence_trigger_overrides_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that sentence triggers are checked before a non-default conversation agent."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"test trigger sentence",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"set_conversation_response": "test trigger response",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test trigger sentence",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
intent_agent="test-agent", # not the default agent
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Sentence trigger should have been handled
|
||||
mock_async_converse.assert_not_called()
|
||||
|
||||
# Verify sentence trigger response
|
||||
intent_end_event = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert (intent_end_event is not None) and intent_end_event.data
|
||||
assert (
|
||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||
"speech"
|
||||
]
|
||||
== "test trigger response"
|
||||
)
|
||||
|
||||
|
||||
async def test_prefer_local_intents(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that the default agent is checked first when local intents are preferred."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Reuse custom sentences in test config
|
||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||
intent_type = "OrderBeer"
|
||||
|
||||
async def async_handle(
|
||||
self, intent_obj: intent.Intent
|
||||
) -> intent.IntentResponse:
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech("Order confirmed")
|
||||
return response
|
||||
|
||||
handler = OrderBeerIntentHandler()
|
||||
intent.async_register(hass, handler)
|
||||
|
||||
# Fake a test agent and prefer local intents
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="I'd like to order a stout please",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Test agent should not have been called
|
||||
mock_async_converse.assert_not_called()
|
||||
|
||||
# Verify local intent response
|
||||
intent_end_event = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert (intent_end_event is not None) and intent_end_event.data
|
||||
assert (
|
||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||
"speech"
|
||||
]
|
||||
== "Order confirmed"
|
||||
)
|
||||
|
||||
|
||||
async def test_intent_continue_conversation(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that a conversation agent flagging continue conversation gets response."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Fake a test agent and prefer local intents
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine="test-agent"
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Set a timer",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("For how long?")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
continue_conversation=True,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[1]["intent_output"]["continue_conversation"] is True
|
||||
|
||||
# Change conversation agent to default one and register sentence trigger that should not be called
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine=None
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": ["Hello"],
|
||||
},
|
||||
"action": {
|
||||
"set_conversation_response": "test trigger response",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Because we did continue conversation, it should respond to the test agent again.
|
||||
events.clear()
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Hello",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
) as mock_prepare:
|
||||
await pipeline_input.validate()
|
||||
|
||||
# It requested test agent even if that was not default agent.
|
||||
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("Timer set for 20 minutes")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
# Snapshot will show it was still handled by the test agent and not default agent
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[0]["engine"] == "test-agent"
|
||||
assert results[1]["intent_output"]["continue_conversation"] is False
|
||||
|
||||
|
||||
async def test_stt_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.stt_language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.stt_language
|
||||
)
|
||||
|
||||
|
||||
async def test_tts_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": "en-us",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.tts_language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.tts_language
|
||||
)
|
||||
|
||||
|
||||
async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.language
|
||||
)
|
||||
|
||||
|
||||
async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||
|
@ -1,13 +1,20 @@
|
||||
"""Websocket tests for Voice Assistant integration."""
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
|
||||
from homeassistant.components import conversation
|
||||
from homeassistant.components import (
|
||||
assist_pipeline,
|
||||
conversation,
|
||||
media_source,
|
||||
stt,
|
||||
tts,
|
||||
)
|
||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||
from homeassistant.components.assist_pipeline.pipeline import (
|
||||
STORAGE_KEY,
|
||||
@ -24,14 +31,22 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
||||
async_migrate_engine,
|
||||
async_update_pipeline,
|
||||
)
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import intent
|
||||
from homeassistant.const import MATCH_ALL
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import chat_session, intent
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from . import MANY_LANGUAGES
|
||||
from .conftest import MockSTTProviderEntity, MockTTSProvider
|
||||
from . import MANY_LANGUAGES, process_events
|
||||
from .conftest import (
|
||||
MockSTTProvider,
|
||||
MockSTTProviderEntity,
|
||||
MockTTSProvider,
|
||||
MockWakeWordEntity,
|
||||
make_10ms_chunk,
|
||||
)
|
||||
|
||||
from tests.common import flush_store
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@ -119,6 +134,22 @@ async def test_load_pipelines(hass: HomeAssistant) -> None:
|
||||
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_chat_session_id() -> Generator[Mock]:
|
||||
"""Mock the conversation ID of chat sessions."""
|
||||
with patch(
|
||||
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
|
||||
) as mock_ulid_now:
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_tts_token() -> Generator[None]:
|
||||
"""Mock the TTS token for URLs."""
|
||||
with patch("secrets.token_urlsafe", return_value="mocked-token"):
|
||||
yield
|
||||
|
||||
|
||||
async def test_loading_pipelines_from_storage(
|
||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||
) -> None:
|
||||
@ -697,3 +728,820 @@ def test_fallback_intent_filter() -> None:
|
||||
)
|
||||
is False
|
||||
)
|
||||
|
||||
|
||||
async def test_wake_word_detection_aborted(
|
||||
hass: HomeAssistant,
|
||||
mock_stt_provider: MockSTTProvider,
|
||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test wake word stream is first detected, then aborted."""
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
async def audio_data():
|
||||
yield make_10ms_chunk(b"silence!")
|
||||
yield make_10ms_chunk(b"wake word!")
|
||||
yield make_10ms_chunk(b"part1")
|
||||
yield make_10ms_chunk(b"part2")
|
||||
yield b""
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
format=stt.AudioFormats.WAV,
|
||||
codec=stt.AudioCodecs.PCM,
|
||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||
),
|
||||
stt_stream=audio_data(),
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output=None,
|
||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||
audio_seconds_to_buffer=1.5
|
||||
),
|
||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
updates = pipeline.to_json()
|
||||
updates.pop("id")
|
||||
await pipeline_store.async_update_item(
|
||||
pipeline_id,
|
||||
updates,
|
||||
)
|
||||
await pipeline_input.execute()
|
||||
|
||||
assert process_events(events) == snapshot
|
||||
|
||||
|
||||
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
||||
"""Test that pipeline run equality uses unique id."""
|
||||
|
||||
def event_callback(event):
|
||||
pass
|
||||
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
|
||||
run_1 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
run_2 = assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.STT,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=event_callback,
|
||||
)
|
||||
|
||||
assert run_1 == run_1 # noqa: PLR0124
|
||||
assert run_1 != run_2
|
||||
assert run_1 != 1234
|
||||
|
||||
|
||||
async def test_tts_audio_output(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test using tts_audio_output with wav sets options correctly."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Verify TTS audio settings
|
||||
assert pipeline_input.run.tts_stream.options is not None
|
||||
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
|
||||
== 16000
|
||||
)
|
||||
assert (
|
||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||
== 1
|
||||
)
|
||||
|
||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
# Ensure that no unsupported options were passed in
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
||||
assert len(extra_options) == 0, extra_options
|
||||
|
||||
|
||||
async def test_tts_wav_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output="wav",
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS provider support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
||||
|
||||
async def test_tts_dict_preferred_format(
|
||||
hass: HomeAssistant,
|
||||
hass_client: ClientSessionGenerator,
|
||||
mock_tts_provider: MockTTSProvider,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that preferred format options are given to the TTS system if supported."""
|
||||
client = await hass_client()
|
||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
session=mock_chat_session,
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||
event_callback=events.append,
|
||||
tts_audio_output={
|
||||
tts.ATTR_PREFERRED_FORMAT: "flac",
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||
},
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
# Make the TTS provider support preferred format options
|
||||
supported_options = list(mock_tts_provider.supported_options or [])
|
||||
supported_options.extend(
|
||||
[
|
||||
tts.ATTR_PREFERRED_FORMAT,
|
||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||
]
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||
):
|
||||
await pipeline_input.execute()
|
||||
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||
# We must fetch the media URL to trigger the TTS
|
||||
assert event.data
|
||||
await client.get(event.data["tts_output"]["url"])
|
||||
|
||||
assert mock_get_tts_audio.called
|
||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||
|
||||
# We should have received preferred format options in get_tts_audio
|
||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||
|
||||
|
||||
async def test_sentence_trigger_overrides_conversation_agent(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that sentence triggers are checked before a non-default conversation agent."""
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": [
|
||||
"test trigger sentence",
|
||||
],
|
||||
},
|
||||
"action": {
|
||||
"set_conversation_response": "test trigger response",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test trigger sentence",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
intent_agent="test-agent", # not the default agent
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Sentence trigger should have been handled
|
||||
mock_async_converse.assert_not_called()
|
||||
|
||||
# Verify sentence trigger response
|
||||
intent_end_event = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert (intent_end_event is not None) and intent_end_event.data
|
||||
assert (
|
||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||
"speech"
|
||||
]
|
||||
== "test trigger response"
|
||||
)
|
||||
|
||||
|
||||
async def test_prefer_local_intents(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that the default agent is checked first when local intents are preferred."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Reuse custom sentences in test config
|
||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||
intent_type = "OrderBeer"
|
||||
|
||||
async def async_handle(
|
||||
self, intent_obj: intent.Intent
|
||||
) -> intent.IntentResponse:
|
||||
response = intent_obj.create_response()
|
||||
response.async_set_speech("Order confirmed")
|
||||
return response
|
||||
|
||||
handler = OrderBeerIntentHandler()
|
||||
intent.async_register(hass, handler)
|
||||
|
||||
# Fake a test agent and prefer local intents
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="I'd like to order a stout please",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Test agent should not have been called
|
||||
mock_async_converse.assert_not_called()
|
||||
|
||||
# Verify local intent response
|
||||
intent_end_event = next(
|
||||
(
|
||||
e
|
||||
for e in events
|
||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||
),
|
||||
None,
|
||||
)
|
||||
assert (intent_end_event is not None) and intent_end_event.data
|
||||
assert (
|
||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||
"speech"
|
||||
]
|
||||
== "Order confirmed"
|
||||
)
|
||||
|
||||
|
||||
async def test_intent_continue_conversation(
|
||||
hass: HomeAssistant,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||
) -> None:
|
||||
"""Test that a conversation agent flagging continue conversation gets response."""
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
# Fake a test agent and prefer local intents
|
||||
pipeline_store = pipeline_data.pipeline_store
|
||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine="test-agent"
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Set a timer",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
):
|
||||
await pipeline_input.validate()
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("For how long?")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
continue_conversation=True,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[1]["intent_output"]["continue_conversation"] is True
|
||||
|
||||
# Change conversation agent to default one and register sentence trigger that should not be called
|
||||
await assist_pipeline.pipeline.async_update_pipeline(
|
||||
hass, pipeline, conversation_engine=None
|
||||
)
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
assert await async_setup_component(
|
||||
hass,
|
||||
"automation",
|
||||
{
|
||||
"automation": {
|
||||
"trigger": {
|
||||
"platform": "conversation",
|
||||
"command": ["Hello"],
|
||||
},
|
||||
"action": {
|
||||
"set_conversation_response": "test trigger response",
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
# Because we did continue conversation, it should respond to the test agent again.
|
||||
events.clear()
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="Hello",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
|
||||
# Ensure prepare succeeds
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||
) as mock_prepare:
|
||||
await pipeline_input.validate()
|
||||
|
||||
# It requested test agent even if that was not default agent.
|
||||
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
|
||||
|
||||
response = intent.IntentResponse("en")
|
||||
response.async_set_speech("Timer set for 20 minutes")
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
response=response,
|
||||
conversation_id=mock_chat_session.conversation_id,
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
mock_async_converse.assert_called()
|
||||
|
||||
# Snapshot will show it was still handled by the test agent and not default agent
|
||||
results = [
|
||||
event.data
|
||||
for event in events
|
||||
if event.type
|
||||
in (
|
||||
assist_pipeline.PipelineEventType.INTENT_START,
|
||||
assist_pipeline.PipelineEventType.INTENT_END,
|
||||
)
|
||||
]
|
||||
assert results[0]["engine"] == "test-agent"
|
||||
assert results[1]["intent_output"]["continue_conversation"] is False
|
||||
|
||||
|
||||
async def test_stt_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": "test",
|
||||
"stt_language": "en-US",
|
||||
"tts_engine": "test",
|
||||
"tts_language": "en-US",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.stt_language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.stt_language
|
||||
)
|
||||
|
||||
|
||||
async def test_tts_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": "en-us",
|
||||
"tts_voice": "Arnold Schwarzenegger",
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.tts_language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.tts_language
|
||||
)
|
||||
|
||||
|
||||
async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
hass: HomeAssistant,
|
||||
hass_ws_client: WebSocketGenerator,
|
||||
init_components,
|
||||
mock_chat_session: chat_session.ChatSession,
|
||||
snapshot: SnapshotAssertion,
|
||||
) -> None:
|
||||
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
||||
client = await hass_ws_client(hass)
|
||||
|
||||
events: list[assist_pipeline.PipelineEvent] = []
|
||||
|
||||
await client.send_json_auto_id(
|
||||
{
|
||||
"type": "assist_pipeline/pipeline/create",
|
||||
"conversation_engine": "homeassistant",
|
||||
"conversation_language": MATCH_ALL,
|
||||
"language": "en",
|
||||
"name": "test_name",
|
||||
"stt_engine": None,
|
||||
"stt_language": None,
|
||||
"tts_engine": None,
|
||||
"tts_language": None,
|
||||
"tts_voice": None,
|
||||
"wake_word_entity": None,
|
||||
"wake_word_id": None,
|
||||
}
|
||||
)
|
||||
msg = await client.receive_json()
|
||||
assert msg["success"]
|
||||
pipeline_id = msg["result"]["id"]
|
||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
session=mock_chat_session,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
pipeline=pipeline,
|
||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||
event_callback=events.append,
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
|
||||
with patch(
|
||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||
return_value=conversation.ConversationResult(
|
||||
intent.IntentResponse(pipeline.language)
|
||||
),
|
||||
) as mock_async_converse:
|
||||
await pipeline_input.execute()
|
||||
|
||||
# Check intent start event
|
||||
assert process_events(events) == snapshot
|
||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||
for event in events:
|
||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||
intent_start = event
|
||||
break
|
||||
|
||||
assert intent_start is not None
|
||||
|
||||
# STT language (en-US) should be used instead of '*'
|
||||
assert intent_start.data.get("language") == pipeline.language
|
||||
|
||||
# Check input to async_converse
|
||||
mock_async_converse.assert_called_once()
|
||||
assert (
|
||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||
== pipeline.language
|
||||
)
|
||||
|
Loading…
x
Reference in New Issue
Block a user