mirror of
https://github.com/home-assistant/core.git
synced 2025-07-29 08:07:45 +00:00
Move Assist Pipeline tests to right file (#144696)
This commit is contained in:
parent
b394c07a3d
commit
4faa920318
@ -1,5 +1,10 @@
|
|||||||
"""Tests for the Voice Assistant integration."""
|
"""Tests for the Voice Assistant integration."""
|
||||||
|
|
||||||
|
from dataclasses import asdict
|
||||||
|
from unittest.mock import ANY
|
||||||
|
|
||||||
|
from homeassistant.components import assist_pipeline
|
||||||
|
|
||||||
MANY_LANGUAGES = [
|
MANY_LANGUAGES = [
|
||||||
"ar",
|
"ar",
|
||||||
"bg",
|
"bg",
|
||||||
@ -54,3 +59,16 @@ MANY_LANGUAGES = [
|
|||||||
"zh-hk",
|
"zh-hk",
|
||||||
"zh-tw",
|
"zh-tw",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||||
|
"""Process events to remove dynamic values."""
|
||||||
|
processed = []
|
||||||
|
for event in events:
|
||||||
|
as_dict = asdict(event)
|
||||||
|
as_dict.pop("timestamp")
|
||||||
|
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
|
||||||
|
as_dict["data"]["pipeline"] = ANY
|
||||||
|
processed.append(as_dict)
|
||||||
|
|
||||||
|
return processed
|
||||||
|
@ -461,204 +461,3 @@
|
|||||||
}),
|
}),
|
||||||
])
|
])
|
||||||
# ---
|
# ---
|
||||||
# name: test_pipeline_language_used_instead_of_conversation_language
|
|
||||||
list([
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'language': 'en',
|
|
||||||
'pipeline': <ANY>,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'device_id': None,
|
|
||||||
'engine': 'conversation.home_assistant',
|
|
||||||
'intent_input': 'test input',
|
|
||||||
'language': 'en',
|
|
||||||
'prefer_local_intents': False,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'intent_output': dict({
|
|
||||||
'continue_conversation': False,
|
|
||||||
'conversation_id': <ANY>,
|
|
||||||
'response': dict({
|
|
||||||
'card': dict({
|
|
||||||
}),
|
|
||||||
'data': dict({
|
|
||||||
'failed': list([
|
|
||||||
]),
|
|
||||||
'success': list([
|
|
||||||
]),
|
|
||||||
'targets': list([
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
'language': 'en',
|
|
||||||
'response_type': 'action_done',
|
|
||||||
'speech': dict({
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
'processed_locally': True,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': None,
|
|
||||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
# name: test_stt_language_used_instead_of_conversation_language
|
|
||||||
list([
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'language': 'en',
|
|
||||||
'pipeline': <ANY>,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'device_id': None,
|
|
||||||
'engine': 'conversation.home_assistant',
|
|
||||||
'intent_input': 'test input',
|
|
||||||
'language': 'en-US',
|
|
||||||
'prefer_local_intents': False,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'intent_output': dict({
|
|
||||||
'continue_conversation': False,
|
|
||||||
'conversation_id': <ANY>,
|
|
||||||
'response': dict({
|
|
||||||
'card': dict({
|
|
||||||
}),
|
|
||||||
'data': dict({
|
|
||||||
'failed': list([
|
|
||||||
]),
|
|
||||||
'success': list([
|
|
||||||
]),
|
|
||||||
'targets': list([
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
'language': 'en',
|
|
||||||
'response_type': 'action_done',
|
|
||||||
'speech': dict({
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
'processed_locally': True,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': None,
|
|
||||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
# name: test_tts_language_used_instead_of_conversation_language
|
|
||||||
list([
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'language': 'en',
|
|
||||||
'pipeline': <ANY>,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'device_id': None,
|
|
||||||
'engine': 'conversation.home_assistant',
|
|
||||||
'intent_input': 'test input',
|
|
||||||
'language': 'en-us',
|
|
||||||
'prefer_local_intents': False,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'intent_output': dict({
|
|
||||||
'continue_conversation': False,
|
|
||||||
'conversation_id': <ANY>,
|
|
||||||
'response': dict({
|
|
||||||
'card': dict({
|
|
||||||
}),
|
|
||||||
'data': dict({
|
|
||||||
'failed': list([
|
|
||||||
]),
|
|
||||||
'success': list([
|
|
||||||
]),
|
|
||||||
'targets': list([
|
|
||||||
]),
|
|
||||||
}),
|
|
||||||
'language': 'en',
|
|
||||||
'response_type': 'action_done',
|
|
||||||
'speech': dict({
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
'processed_locally': True,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': None,
|
|
||||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
# name: test_wake_word_detection_aborted
|
|
||||||
list([
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'conversation_id': 'mock-ulid',
|
|
||||||
'language': 'en',
|
|
||||||
'pipeline': <ANY>,
|
|
||||||
'tts_output': dict({
|
|
||||||
'mime_type': 'audio/mpeg',
|
|
||||||
'token': 'mocked-token.mp3',
|
|
||||||
'url': '/api/tts_proxy/mocked-token.mp3',
|
|
||||||
}),
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'entity_id': 'wake_word.test',
|
|
||||||
'metadata': dict({
|
|
||||||
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
|
||||||
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
|
||||||
'codec': <AudioCodecs.PCM: 'pcm'>,
|
|
||||||
'format': <AudioFormats.WAV: 'wav'>,
|
|
||||||
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
|
||||||
}),
|
|
||||||
'timeout': 0,
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': dict({
|
|
||||||
'code': 'wake_word_detection_aborted',
|
|
||||||
'message': '',
|
|
||||||
}),
|
|
||||||
'type': <PipelineEventType.ERROR: 'error'>,
|
|
||||||
}),
|
|
||||||
dict({
|
|
||||||
'data': None,
|
|
||||||
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
|
||||||
}),
|
|
||||||
])
|
|
||||||
# ---
|
|
||||||
|
202
tests/components/assist_pipeline/snapshots/test_pipeline.ambr
Normal file
202
tests/components/assist_pipeline/snapshots/test_pipeline.ambr
Normal file
@ -0,0 +1,202 @@
|
|||||||
|
# serializer version: 1
|
||||||
|
# name: test_pipeline_language_used_instead_of_conversation_language
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'test input',
|
||||||
|
'language': 'en',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': False,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': True,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_stt_language_used_instead_of_conversation_language
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'test input',
|
||||||
|
'language': 'en-US',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': False,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': True,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_tts_language_used_instead_of_conversation_language
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'device_id': None,
|
||||||
|
'engine': 'conversation.home_assistant',
|
||||||
|
'intent_input': 'test input',
|
||||||
|
'language': 'en-us',
|
||||||
|
'prefer_local_intents': False,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_START: 'intent-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'intent_output': dict({
|
||||||
|
'continue_conversation': False,
|
||||||
|
'conversation_id': <ANY>,
|
||||||
|
'response': dict({
|
||||||
|
'card': dict({
|
||||||
|
}),
|
||||||
|
'data': dict({
|
||||||
|
'failed': list([
|
||||||
|
]),
|
||||||
|
'success': list([
|
||||||
|
]),
|
||||||
|
'targets': list([
|
||||||
|
]),
|
||||||
|
}),
|
||||||
|
'language': 'en',
|
||||||
|
'response_type': 'action_done',
|
||||||
|
'speech': dict({
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'processed_locally': True,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.INTENT_END: 'intent-end'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
||||||
|
# name: test_wake_word_detection_aborted
|
||||||
|
list([
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
|
'language': 'en',
|
||||||
|
'pipeline': <ANY>,
|
||||||
|
'tts_output': dict({
|
||||||
|
'mime_type': 'audio/mpeg',
|
||||||
|
'token': 'mocked-token.mp3',
|
||||||
|
'url': '/api/tts_proxy/mocked-token.mp3',
|
||||||
|
}),
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.RUN_START: 'run-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'entity_id': 'wake_word.test',
|
||||||
|
'metadata': dict({
|
||||||
|
'bit_rate': <AudioBitRates.BITRATE_16: 16>,
|
||||||
|
'channel': <AudioChannels.CHANNEL_MONO: 1>,
|
||||||
|
'codec': <AudioCodecs.PCM: 'pcm'>,
|
||||||
|
'format': <AudioFormats.WAV: 'wav'>,
|
||||||
|
'sample_rate': <AudioSampleRates.SAMPLERATE_16000: 16000>,
|
||||||
|
}),
|
||||||
|
'timeout': 0,
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.WAKE_WORD_START: 'wake_word-start'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': dict({
|
||||||
|
'code': 'wake_word_detection_aborted',
|
||||||
|
'message': '',
|
||||||
|
}),
|
||||||
|
'type': <PipelineEventType.ERROR: 'error'>,
|
||||||
|
}),
|
||||||
|
dict({
|
||||||
|
'data': None,
|
||||||
|
'type': <PipelineEventType.RUN_END: 'run-end'>,
|
||||||
|
}),
|
||||||
|
])
|
||||||
|
# ---
|
@ -2,44 +2,35 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from dataclasses import asdict
|
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import ANY, Mock, patch
|
from unittest.mock import Mock, patch
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import hass_nabucasa
|
import hass_nabucasa
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import (
|
from homeassistant.components import assist_pipeline, stt
|
||||||
assist_pipeline,
|
|
||||||
conversation,
|
|
||||||
media_source,
|
|
||||||
stt,
|
|
||||||
tts,
|
|
||||||
)
|
|
||||||
from homeassistant.components.assist_pipeline.const import (
|
from homeassistant.components.assist_pipeline.const import (
|
||||||
BYTES_PER_CHUNK,
|
BYTES_PER_CHUNK,
|
||||||
CONF_DEBUG_RECORDING_DIR,
|
CONF_DEBUG_RECORDING_DIR,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
)
|
)
|
||||||
from homeassistant.const import MATCH_ALL
|
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
from homeassistant.helpers import chat_session, intent
|
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
|
from . import process_events
|
||||||
from .conftest import (
|
from .conftest import (
|
||||||
BYTES_ONE_SECOND,
|
BYTES_ONE_SECOND,
|
||||||
MockSTTProvider,
|
MockSTTProvider,
|
||||||
MockSTTProviderEntity,
|
MockSTTProviderEntity,
|
||||||
MockTTSProvider,
|
|
||||||
MockWakeWordEntity,
|
MockWakeWordEntity,
|
||||||
make_10ms_chunk,
|
make_10ms_chunk,
|
||||||
)
|
)
|
||||||
|
|
||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -58,19 +49,6 @@ def mock_tts_token() -> Generator[None]:
|
|||||||
yield
|
yield
|
||||||
|
|
||||||
|
|
||||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
|
||||||
"""Process events to remove dynamic values."""
|
|
||||||
processed = []
|
|
||||||
for event in events:
|
|
||||||
as_dict = asdict(event)
|
|
||||||
as_dict.pop("timestamp")
|
|
||||||
if as_dict["type"] == assist_pipeline.PipelineEventType.RUN_START:
|
|
||||||
as_dict["data"]["pipeline"] = ANY
|
|
||||||
processed.append(as_dict)
|
|
||||||
|
|
||||||
return processed
|
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream_auto(
|
async def test_pipeline_from_audio_stream_auto(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
@ -677,823 +655,6 @@ async def test_pipeline_saved_audio_empty_queue(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def test_wake_word_detection_aborted(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
mock_stt_provider: MockSTTProvider,
|
|
||||||
mock_wake_word_provider_entity: MockWakeWordEntity,
|
|
||||||
init_components,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test creating a pipeline from an audio stream with wake word."""
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
async def audio_data():
|
|
||||||
yield make_10ms_chunk(b"silence!")
|
|
||||||
yield make_10ms_chunk(b"wake word!")
|
|
||||||
yield make_10ms_chunk(b"part1")
|
|
||||||
yield make_10ms_chunk(b"part2")
|
|
||||||
yield b""
|
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
session=mock_chat_session,
|
|
||||||
device_id=None,
|
|
||||||
stt_metadata=stt.SpeechMetadata(
|
|
||||||
language="",
|
|
||||||
format=stt.AudioFormats.WAV,
|
|
||||||
codec=stt.AudioCodecs.PCM,
|
|
||||||
bit_rate=stt.AudioBitRates.BITRATE_16,
|
|
||||||
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
|
||||||
channel=stt.AudioChannels.CHANNEL_MONO,
|
|
||||||
),
|
|
||||||
stt_stream=audio_data(),
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=events.append,
|
|
||||||
tts_audio_output=None,
|
|
||||||
wake_word_settings=assist_pipeline.WakeWordSettings(
|
|
||||||
audio_seconds_to_buffer=1.5
|
|
||||||
),
|
|
||||||
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
updates = pipeline.to_json()
|
|
||||||
updates.pop("id")
|
|
||||||
await pipeline_store.async_update_item(
|
|
||||||
pipeline_id,
|
|
||||||
updates,
|
|
||||||
)
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
assert process_events(events) == snapshot
|
|
||||||
|
|
||||||
|
|
||||||
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
|
||||||
"""Test that pipeline run equality uses unique id."""
|
|
||||||
|
|
||||||
def event_callback(event):
|
|
||||||
pass
|
|
||||||
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
|
|
||||||
run_1 = assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.STT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=event_callback,
|
|
||||||
)
|
|
||||||
run_2 = assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.STT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=event_callback,
|
|
||||||
)
|
|
||||||
|
|
||||||
assert run_1 == run_1 # noqa: PLR0124
|
|
||||||
assert run_1 != run_2
|
|
||||||
assert run_1 != 1234
|
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_audio_output(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_client: ClientSessionGenerator,
|
|
||||||
mock_tts_provider: MockTTSProvider,
|
|
||||||
init_components,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test using tts_audio_output with wav sets options correctly."""
|
|
||||||
client = await hass_client()
|
|
||||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
tts_input="This is a test.",
|
|
||||||
session=mock_chat_session,
|
|
||||||
device_id=None,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=events.append,
|
|
||||||
tts_audio_output="wav",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
# Verify TTS audio settings
|
|
||||||
assert pipeline_input.run.tts_stream.options is not None
|
|
||||||
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
|
||||||
assert (
|
|
||||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
|
|
||||||
== 16000
|
|
||||||
)
|
|
||||||
assert (
|
|
||||||
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
|
|
||||||
== 1
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
|
||||||
# We must fetch the media URL to trigger the TTS
|
|
||||||
assert event.data
|
|
||||||
await client.get(event.data["tts_output"]["url"])
|
|
||||||
|
|
||||||
# Ensure that no unsupported options were passed in
|
|
||||||
assert mock_get_tts_audio.called
|
|
||||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
|
||||||
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
|
||||||
assert len(extra_options) == 0, extra_options
|
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_wav_preferred_format(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_client: ClientSessionGenerator,
|
|
||||||
mock_tts_provider: MockTTSProvider,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
) -> None:
|
|
||||||
"""Test that preferred format options are given to the TTS system if supported."""
|
|
||||||
client = await hass_client()
|
|
||||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
tts_input="This is a test.",
|
|
||||||
session=mock_chat_session,
|
|
||||||
device_id=None,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=events.append,
|
|
||||||
tts_audio_output="wav",
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
# Make the TTS provider support preferred format options
|
|
||||||
supported_options = list(mock_tts_provider.supported_options or [])
|
|
||||||
supported_options.extend(
|
|
||||||
[
|
|
||||||
tts.ATTR_PREFERRED_FORMAT,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
|
||||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
|
||||||
):
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
|
||||||
# We must fetch the media URL to trigger the TTS
|
|
||||||
assert event.data
|
|
||||||
await client.get(event.data["tts_output"]["url"])
|
|
||||||
|
|
||||||
assert mock_get_tts_audio.called
|
|
||||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
|
||||||
|
|
||||||
# We should have received preferred format options in get_tts_audio
|
|
||||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_dict_preferred_format(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_client: ClientSessionGenerator,
|
|
||||||
mock_tts_provider: MockTTSProvider,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
) -> None:
|
|
||||||
"""Test that preferred format options are given to the TTS system if supported."""
|
|
||||||
client = await hass_client()
|
|
||||||
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
tts_input="This is a test.",
|
|
||||||
session=mock_chat_session,
|
|
||||||
device_id=None,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.TTS,
|
|
||||||
event_callback=events.append,
|
|
||||||
tts_audio_output={
|
|
||||||
tts.ATTR_PREFERRED_FORMAT: "flac",
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
|
||||||
},
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
# Make the TTS provider support preferred format options
|
|
||||||
supported_options = list(mock_tts_provider.supported_options or [])
|
|
||||||
supported_options.extend(
|
|
||||||
[
|
|
||||||
tts.ATTR_PREFERRED_FORMAT,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
|
||||||
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
with (
|
|
||||||
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
|
||||||
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
|
||||||
):
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
|
||||||
# We must fetch the media URL to trigger the TTS
|
|
||||||
assert event.data
|
|
||||||
await client.get(event.data["tts_output"]["url"])
|
|
||||||
|
|
||||||
assert mock_get_tts_audio.called
|
|
||||||
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
|
||||||
|
|
||||||
# We should have received preferred format options in get_tts_audio
|
|
||||||
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
|
||||||
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
|
||||||
|
|
||||||
|
|
||||||
async def test_sentence_trigger_overrides_conversation_agent(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
) -> None:
|
|
||||||
"""Test that sentence triggers are checked before a non-default conversation agent."""
|
|
||||||
assert await async_setup_component(
|
|
||||||
hass,
|
|
||||||
"automation",
|
|
||||||
{
|
|
||||||
"automation": {
|
|
||||||
"trigger": {
|
|
||||||
"platform": "conversation",
|
|
||||||
"command": [
|
|
||||||
"test trigger sentence",
|
|
||||||
],
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"set_conversation_response": "test trigger response",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="test trigger sentence",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
intent_agent="test-agent", # not the default agent
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure prepare succeeds
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
|
||||||
):
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
# Sentence trigger should have been handled
|
|
||||||
mock_async_converse.assert_not_called()
|
|
||||||
|
|
||||||
# Verify sentence trigger response
|
|
||||||
intent_end_event = next(
|
|
||||||
(
|
|
||||||
e
|
|
||||||
for e in events
|
|
||||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert (intent_end_event is not None) and intent_end_event.data
|
|
||||||
assert (
|
|
||||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
|
||||||
"speech"
|
|
||||||
]
|
|
||||||
== "test trigger response"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_prefer_local_intents(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the default agent is checked first when local intents are preferred."""
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
# Reuse custom sentences in test config
|
|
||||||
class OrderBeerIntentHandler(intent.IntentHandler):
|
|
||||||
intent_type = "OrderBeer"
|
|
||||||
|
|
||||||
async def async_handle(
|
|
||||||
self, intent_obj: intent.Intent
|
|
||||||
) -> intent.IntentResponse:
|
|
||||||
response = intent_obj.create_response()
|
|
||||||
response.async_set_speech("Order confirmed")
|
|
||||||
return response
|
|
||||||
|
|
||||||
handler = OrderBeerIntentHandler()
|
|
||||||
intent.async_register(hass, handler)
|
|
||||||
|
|
||||||
# Fake a test agent and prefer local intents
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
await assist_pipeline.pipeline.async_update_pipeline(
|
|
||||||
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
|
|
||||||
)
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="I'd like to order a stout please",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure prepare succeeds
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
|
||||||
):
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
# Test agent should not have been called
|
|
||||||
mock_async_converse.assert_not_called()
|
|
||||||
|
|
||||||
# Verify local intent response
|
|
||||||
intent_end_event = next(
|
|
||||||
(
|
|
||||||
e
|
|
||||||
for e in events
|
|
||||||
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
assert (intent_end_event is not None) and intent_end_event.data
|
|
||||||
assert (
|
|
||||||
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
|
||||||
"speech"
|
|
||||||
]
|
|
||||||
== "Order confirmed"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_intent_continue_conversation(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
|
||||||
) -> None:
|
|
||||||
"""Test that a conversation agent flagging continue conversation gets response."""
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
# Fake a test agent and prefer local intents
|
|
||||||
pipeline_store = pipeline_data.pipeline_store
|
|
||||||
pipeline_id = pipeline_store.async_get_preferred_item()
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
await assist_pipeline.pipeline.async_update_pipeline(
|
|
||||||
hass, pipeline, conversation_engine="test-agent"
|
|
||||||
)
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="Set a timer",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure prepare succeeds
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
|
||||||
):
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
response = intent.IntentResponse("en")
|
|
||||||
response.async_set_speech("For how long?")
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
|
||||||
return_value=conversation.ConversationResult(
|
|
||||||
response=response,
|
|
||||||
conversation_id=mock_chat_session.conversation_id,
|
|
||||||
continue_conversation=True,
|
|
||||||
),
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
mock_async_converse.assert_called()
|
|
||||||
|
|
||||||
results = [
|
|
||||||
event.data
|
|
||||||
for event in events
|
|
||||||
if event.type
|
|
||||||
in (
|
|
||||||
assist_pipeline.PipelineEventType.INTENT_START,
|
|
||||||
assist_pipeline.PipelineEventType.INTENT_END,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
assert results[1]["intent_output"]["continue_conversation"] is True
|
|
||||||
|
|
||||||
# Change conversation agent to default one and register sentence trigger that should not be called
|
|
||||||
await assist_pipeline.pipeline.async_update_pipeline(
|
|
||||||
hass, pipeline, conversation_engine=None
|
|
||||||
)
|
|
||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
assert await async_setup_component(
|
|
||||||
hass,
|
|
||||||
"automation",
|
|
||||||
{
|
|
||||||
"automation": {
|
|
||||||
"trigger": {
|
|
||||||
"platform": "conversation",
|
|
||||||
"command": ["Hello"],
|
|
||||||
},
|
|
||||||
"action": {
|
|
||||||
"set_conversation_response": "test trigger response",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
# Because we did continue conversation, it should respond to the test agent again.
|
|
||||||
events.clear()
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="Hello",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Ensure prepare succeeds
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
|
||||||
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
|
||||||
) as mock_prepare:
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
# It requested test agent even if that was not default agent.
|
|
||||||
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
|
|
||||||
|
|
||||||
response = intent.IntentResponse("en")
|
|
||||||
response.async_set_speech("Timer set for 20 minutes")
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
|
||||||
return_value=conversation.ConversationResult(
|
|
||||||
response=response,
|
|
||||||
conversation_id=mock_chat_session.conversation_id,
|
|
||||||
),
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
mock_async_converse.assert_called()
|
|
||||||
|
|
||||||
# Snapshot will show it was still handled by the test agent and not default agent
|
|
||||||
results = [
|
|
||||||
event.data
|
|
||||||
for event in events
|
|
||||||
if event.type
|
|
||||||
in (
|
|
||||||
assist_pipeline.PipelineEventType.INTENT_START,
|
|
||||||
assist_pipeline.PipelineEventType.INTENT_END,
|
|
||||||
)
|
|
||||||
]
|
|
||||||
assert results[0]["engine"] == "test-agent"
|
|
||||||
assert results[1]["intent_output"]["continue_conversation"] is False
|
|
||||||
|
|
||||||
|
|
||||||
async def test_stt_language_used_instead_of_conversation_language(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_ws_client: WebSocketGenerator,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
|
||||||
client = await hass_ws_client(hass)
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/create",
|
|
||||||
"conversation_engine": "homeassistant",
|
|
||||||
"conversation_language": MATCH_ALL,
|
|
||||||
"language": "en",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": "test",
|
|
||||||
"stt_language": "en-US",
|
|
||||||
"tts_engine": "test",
|
|
||||||
"tts_language": "en-US",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
pipeline_id = msg["result"]["id"]
|
|
||||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="test input",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
|
||||||
return_value=conversation.ConversationResult(
|
|
||||||
intent.IntentResponse(pipeline.language)
|
|
||||||
),
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
# Check intent start event
|
|
||||||
assert process_events(events) == snapshot
|
|
||||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
|
||||||
intent_start = event
|
|
||||||
break
|
|
||||||
|
|
||||||
assert intent_start is not None
|
|
||||||
|
|
||||||
# STT language (en-US) should be used instead of '*'
|
|
||||||
assert intent_start.data.get("language") == pipeline.stt_language
|
|
||||||
|
|
||||||
# Check input to async_converse
|
|
||||||
mock_async_converse.assert_called_once()
|
|
||||||
assert (
|
|
||||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
|
||||||
== pipeline.stt_language
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_tts_language_used_instead_of_conversation_language(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_ws_client: WebSocketGenerator,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
|
||||||
client = await hass_ws_client(hass)
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/create",
|
|
||||||
"conversation_engine": "homeassistant",
|
|
||||||
"conversation_language": MATCH_ALL,
|
|
||||||
"language": "en",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": None,
|
|
||||||
"stt_language": None,
|
|
||||||
"tts_engine": None,
|
|
||||||
"tts_language": "en-us",
|
|
||||||
"tts_voice": "Arnold Schwarzenegger",
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
pipeline_id = msg["result"]["id"]
|
|
||||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="test input",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
|
||||||
return_value=conversation.ConversationResult(
|
|
||||||
intent.IntentResponse(pipeline.language)
|
|
||||||
),
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
# Check intent start event
|
|
||||||
assert process_events(events) == snapshot
|
|
||||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
|
||||||
intent_start = event
|
|
||||||
break
|
|
||||||
|
|
||||||
assert intent_start is not None
|
|
||||||
|
|
||||||
# STT language (en-US) should be used instead of '*'
|
|
||||||
assert intent_start.data.get("language") == pipeline.tts_language
|
|
||||||
|
|
||||||
# Check input to async_converse
|
|
||||||
mock_async_converse.assert_called_once()
|
|
||||||
assert (
|
|
||||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
|
||||||
== pipeline.tts_language
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_language_used_instead_of_conversation_language(
|
|
||||||
hass: HomeAssistant,
|
|
||||||
hass_ws_client: WebSocketGenerator,
|
|
||||||
init_components,
|
|
||||||
mock_chat_session: chat_session.ChatSession,
|
|
||||||
snapshot: SnapshotAssertion,
|
|
||||||
) -> None:
|
|
||||||
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
|
||||||
client = await hass_ws_client(hass)
|
|
||||||
|
|
||||||
events: list[assist_pipeline.PipelineEvent] = []
|
|
||||||
|
|
||||||
await client.send_json_auto_id(
|
|
||||||
{
|
|
||||||
"type": "assist_pipeline/pipeline/create",
|
|
||||||
"conversation_engine": "homeassistant",
|
|
||||||
"conversation_language": MATCH_ALL,
|
|
||||||
"language": "en",
|
|
||||||
"name": "test_name",
|
|
||||||
"stt_engine": None,
|
|
||||||
"stt_language": None,
|
|
||||||
"tts_engine": None,
|
|
||||||
"tts_language": None,
|
|
||||||
"tts_voice": None,
|
|
||||||
"wake_word_entity": None,
|
|
||||||
"wake_word_id": None,
|
|
||||||
}
|
|
||||||
)
|
|
||||||
msg = await client.receive_json()
|
|
||||||
assert msg["success"]
|
|
||||||
pipeline_id = msg["result"]["id"]
|
|
||||||
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
|
||||||
intent_input="test input",
|
|
||||||
session=mock_chat_session,
|
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
|
||||||
hass,
|
|
||||||
context=Context(),
|
|
||||||
pipeline=pipeline,
|
|
||||||
start_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
end_stage=assist_pipeline.PipelineStage.INTENT,
|
|
||||||
event_callback=events.append,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
await pipeline_input.validate()
|
|
||||||
|
|
||||||
with patch(
|
|
||||||
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
|
||||||
return_value=conversation.ConversationResult(
|
|
||||||
intent.IntentResponse(pipeline.language)
|
|
||||||
),
|
|
||||||
) as mock_async_converse:
|
|
||||||
await pipeline_input.execute()
|
|
||||||
|
|
||||||
# Check intent start event
|
|
||||||
assert process_events(events) == snapshot
|
|
||||||
intent_start: assist_pipeline.PipelineEvent | None = None
|
|
||||||
for event in events:
|
|
||||||
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
|
||||||
intent_start = event
|
|
||||||
break
|
|
||||||
|
|
||||||
assert intent_start is not None
|
|
||||||
|
|
||||||
# STT language (en-US) should be used instead of '*'
|
|
||||||
assert intent_start.data.get("language") == pipeline.language
|
|
||||||
|
|
||||||
# Check input to async_converse
|
|
||||||
mock_async_converse.assert_called_once()
|
|
||||||
assert (
|
|
||||||
mock_async_converse.call_args_list[0].kwargs.get("language")
|
|
||||||
== pipeline.language
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
|
async def test_pipeline_from_audio_stream_with_cloud_auth_fail(
|
||||||
hass: HomeAssistant,
|
hass: HomeAssistant,
|
||||||
mock_stt_provider_entity: MockSTTProviderEntity,
|
mock_stt_provider_entity: MockSTTProviderEntity,
|
||||||
|
@ -1,13 +1,20 @@
|
|||||||
"""Websocket tests for Voice Assistant integration."""
|
"""Websocket tests for Voice Assistant integration."""
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
|
|
||||||
from hassil.recognize import Intent, IntentData, RecognizeResult
|
from hassil.recognize import Intent, IntentData, RecognizeResult
|
||||||
import pytest
|
import pytest
|
||||||
|
from syrupy.assertion import SnapshotAssertion
|
||||||
|
|
||||||
from homeassistant.components import conversation
|
from homeassistant.components import (
|
||||||
|
assist_pipeline,
|
||||||
|
conversation,
|
||||||
|
media_source,
|
||||||
|
stt,
|
||||||
|
tts,
|
||||||
|
)
|
||||||
from homeassistant.components.assist_pipeline.const import DOMAIN
|
from homeassistant.components.assist_pipeline.const import DOMAIN
|
||||||
from homeassistant.components.assist_pipeline.pipeline import (
|
from homeassistant.components.assist_pipeline.pipeline import (
|
||||||
STORAGE_KEY,
|
STORAGE_KEY,
|
||||||
@ -24,14 +31,22 @@ from homeassistant.components.assist_pipeline.pipeline import (
|
|||||||
async_migrate_engine,
|
async_migrate_engine,
|
||||||
async_update_pipeline,
|
async_update_pipeline,
|
||||||
)
|
)
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.const import MATCH_ALL
|
||||||
from homeassistant.helpers import intent
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.helpers import chat_session, intent
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
|
|
||||||
from . import MANY_LANGUAGES
|
from . import MANY_LANGUAGES, process_events
|
||||||
from .conftest import MockSTTProviderEntity, MockTTSProvider
|
from .conftest import (
|
||||||
|
MockSTTProvider,
|
||||||
|
MockSTTProviderEntity,
|
||||||
|
MockTTSProvider,
|
||||||
|
MockWakeWordEntity,
|
||||||
|
make_10ms_chunk,
|
||||||
|
)
|
||||||
|
|
||||||
from tests.common import flush_store
|
from tests.common import flush_store
|
||||||
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
@ -119,6 +134,22 @@ async def test_load_pipelines(hass: HomeAssistant) -> None:
|
|||||||
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
|
assert store1.async_get_preferred_item() == store2.async_get_preferred_item()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_chat_session_id() -> Generator[Mock]:
|
||||||
|
"""Mock the conversation ID of chat sessions."""
|
||||||
|
with patch(
|
||||||
|
"homeassistant.helpers.chat_session.ulid_now", return_value="mock-ulid"
|
||||||
|
) as mock_ulid_now:
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_tts_token() -> Generator[None]:
|
||||||
|
"""Mock the TTS token for URLs."""
|
||||||
|
with patch("secrets.token_urlsafe", return_value="mocked-token"):
|
||||||
|
yield
|
||||||
|
|
||||||
|
|
||||||
async def test_loading_pipelines_from_storage(
|
async def test_loading_pipelines_from_storage(
|
||||||
hass: HomeAssistant, hass_storage: dict[str, Any]
|
hass: HomeAssistant, hass_storage: dict[str, Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
@ -697,3 +728,820 @@ def test_fallback_intent_filter() -> None:
|
|||||||
)
|
)
|
||||||
is False
|
is False
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_wake_word_detection_aborted(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
mock_stt_provider: MockSTTProvider,
|
||||||
|
mock_wake_word_provider_entity: MockWakeWordEntity,
|
||||||
|
init_components,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test wake word stream is first detected, then aborted."""
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
async def audio_data():
|
||||||
|
yield make_10ms_chunk(b"silence!")
|
||||||
|
yield make_10ms_chunk(b"wake word!")
|
||||||
|
yield make_10ms_chunk(b"part1")
|
||||||
|
yield make_10ms_chunk(b"part2")
|
||||||
|
yield b""
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
session=mock_chat_session,
|
||||||
|
device_id=None,
|
||||||
|
stt_metadata=stt.SpeechMetadata(
|
||||||
|
language="",
|
||||||
|
format=stt.AudioFormats.WAV,
|
||||||
|
codec=stt.AudioCodecs.PCM,
|
||||||
|
bit_rate=stt.AudioBitRates.BITRATE_16,
|
||||||
|
sample_rate=stt.AudioSampleRates.SAMPLERATE_16000,
|
||||||
|
channel=stt.AudioChannels.CHANNEL_MONO,
|
||||||
|
),
|
||||||
|
stt_stream=audio_data(),
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.WAKE_WORD,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=events.append,
|
||||||
|
tts_audio_output=None,
|
||||||
|
wake_word_settings=assist_pipeline.WakeWordSettings(
|
||||||
|
audio_seconds_to_buffer=1.5
|
||||||
|
),
|
||||||
|
audio_settings=assist_pipeline.AudioSettings(is_vad_enabled=False),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
updates = pipeline.to_json()
|
||||||
|
updates.pop("id")
|
||||||
|
await pipeline_store.async_update_item(
|
||||||
|
pipeline_id,
|
||||||
|
updates,
|
||||||
|
)
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipeline_run_equality(hass: HomeAssistant, init_components) -> None:
|
||||||
|
"""Test that pipeline run equality uses unique id."""
|
||||||
|
|
||||||
|
def event_callback(event):
|
||||||
|
pass
|
||||||
|
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass)
|
||||||
|
run_1 = assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.STT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=event_callback,
|
||||||
|
)
|
||||||
|
run_2 = assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.STT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=event_callback,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert run_1 == run_1 # noqa: PLR0124
|
||||||
|
assert run_1 != run_2
|
||||||
|
assert run_1 != 1234
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_audio_output(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
init_components,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test using tts_audio_output with wav sets options correctly."""
|
||||||
|
client = await hass_client()
|
||||||
|
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
tts_input="This is a test.",
|
||||||
|
session=mock_chat_session,
|
||||||
|
device_id=None,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=events.append,
|
||||||
|
tts_audio_output="wav",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
# Verify TTS audio settings
|
||||||
|
assert pipeline_input.run.tts_stream.options is not None
|
||||||
|
assert pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||||
|
assert (
|
||||||
|
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)
|
||||||
|
== 16000
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
pipeline_input.run.tts_stream.options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)
|
||||||
|
== 1
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||||
|
# We must fetch the media URL to trigger the TTS
|
||||||
|
assert event.data
|
||||||
|
await client.get(event.data["tts_output"]["url"])
|
||||||
|
|
||||||
|
# Ensure that no unsupported options were passed in
|
||||||
|
assert mock_get_tts_audio.called
|
||||||
|
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||||
|
extra_options = set(options).difference(mock_tts_provider.supported_options)
|
||||||
|
assert len(extra_options) == 0, extra_options
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_wav_preferred_format(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that preferred format options are given to the TTS system if supported."""
|
||||||
|
client = await hass_client()
|
||||||
|
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
tts_input="This is a test.",
|
||||||
|
session=mock_chat_session,
|
||||||
|
device_id=None,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=events.append,
|
||||||
|
tts_audio_output="wav",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
# Make the TTS provider support preferred format options
|
||||||
|
supported_options = list(mock_tts_provider.supported_options or [])
|
||||||
|
supported_options.extend(
|
||||||
|
[
|
||||||
|
tts.ATTR_PREFERRED_FORMAT,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||||
|
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||||
|
):
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||||
|
# We must fetch the media URL to trigger the TTS
|
||||||
|
assert event.data
|
||||||
|
await client.get(event.data["tts_output"]["url"])
|
||||||
|
|
||||||
|
assert mock_get_tts_audio.called
|
||||||
|
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||||
|
|
||||||
|
# We should have received preferred format options in get_tts_audio
|
||||||
|
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "wav"
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 16000
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 1
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_dict_preferred_format(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_client: ClientSessionGenerator,
|
||||||
|
mock_tts_provider: MockTTSProvider,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that preferred format options are given to the TTS system if supported."""
|
||||||
|
client = await hass_client()
|
||||||
|
assert await async_setup_component(hass, media_source.DOMAIN, {})
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
tts_input="This is a test.",
|
||||||
|
session=mock_chat_session,
|
||||||
|
device_id=None,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.TTS,
|
||||||
|
event_callback=events.append,
|
||||||
|
tts_audio_output={
|
||||||
|
tts.ATTR_PREFERRED_FORMAT: "flac",
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_RATE: 48000,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_CHANNELS: 2,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_BYTES: 2,
|
||||||
|
},
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
# Make the TTS provider support preferred format options
|
||||||
|
supported_options = list(mock_tts_provider.supported_options or [])
|
||||||
|
supported_options.extend(
|
||||||
|
[
|
||||||
|
tts.ATTR_PREFERRED_FORMAT,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_RATE,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_CHANNELS,
|
||||||
|
tts.ATTR_PREFERRED_SAMPLE_BYTES,
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch.object(mock_tts_provider, "_supported_options", supported_options),
|
||||||
|
patch.object(mock_tts_provider, "get_tts_audio") as mock_get_tts_audio,
|
||||||
|
):
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.TTS_END:
|
||||||
|
# We must fetch the media URL to trigger the TTS
|
||||||
|
assert event.data
|
||||||
|
await client.get(event.data["tts_output"]["url"])
|
||||||
|
|
||||||
|
assert mock_get_tts_audio.called
|
||||||
|
options = mock_get_tts_audio.call_args_list[0].kwargs["options"]
|
||||||
|
|
||||||
|
# We should have received preferred format options in get_tts_audio
|
||||||
|
assert options.get(tts.ATTR_PREFERRED_FORMAT) == "flac"
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_RATE)) == 48000
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_CHANNELS)) == 2
|
||||||
|
assert int(options.get(tts.ATTR_PREFERRED_SAMPLE_BYTES)) == 2
|
||||||
|
|
||||||
|
|
||||||
|
async def test_sentence_trigger_overrides_conversation_agent(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that sentence triggers are checked before a non-default conversation agent."""
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": [
|
||||||
|
"test trigger sentence",
|
||||||
|
],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"set_conversation_response": "test trigger response",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test trigger sentence",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
intent_agent="test-agent", # not the default agent
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure prepare succeeds
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
):
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Sentence trigger should have been handled
|
||||||
|
mock_async_converse.assert_not_called()
|
||||||
|
|
||||||
|
# Verify sentence trigger response
|
||||||
|
intent_end_event = next(
|
||||||
|
(
|
||||||
|
e
|
||||||
|
for e in events
|
||||||
|
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert (intent_end_event is not None) and intent_end_event.data
|
||||||
|
assert (
|
||||||
|
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||||
|
"speech"
|
||||||
|
]
|
||||||
|
== "test trigger response"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_prefer_local_intents(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the default agent is checked first when local intents are preferred."""
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
# Reuse custom sentences in test config
|
||||||
|
class OrderBeerIntentHandler(intent.IntentHandler):
|
||||||
|
intent_type = "OrderBeer"
|
||||||
|
|
||||||
|
async def async_handle(
|
||||||
|
self, intent_obj: intent.Intent
|
||||||
|
) -> intent.IntentResponse:
|
||||||
|
response = intent_obj.create_response()
|
||||||
|
response.async_set_speech("Order confirmed")
|
||||||
|
return response
|
||||||
|
|
||||||
|
handler = OrderBeerIntentHandler()
|
||||||
|
intent.async_register(hass, handler)
|
||||||
|
|
||||||
|
# Fake a test agent and prefer local intents
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
await assist_pipeline.pipeline.async_update_pipeline(
|
||||||
|
hass, pipeline, conversation_engine="test-agent", prefer_local_intents=True
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="I'd like to order a stout please",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure prepare succeeds
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
):
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse"
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Test agent should not have been called
|
||||||
|
mock_async_converse.assert_not_called()
|
||||||
|
|
||||||
|
# Verify local intent response
|
||||||
|
intent_end_event = next(
|
||||||
|
(
|
||||||
|
e
|
||||||
|
for e in events
|
||||||
|
if e.type == assist_pipeline.PipelineEventType.INTENT_END
|
||||||
|
),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
assert (intent_end_event is not None) and intent_end_event.data
|
||||||
|
assert (
|
||||||
|
intent_end_event.data["intent_output"]["response"]["speech"]["plain"][
|
||||||
|
"speech"
|
||||||
|
]
|
||||||
|
== "Order confirmed"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_intent_continue_conversation(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
pipeline_data: assist_pipeline.pipeline.PipelineData,
|
||||||
|
) -> None:
|
||||||
|
"""Test that a conversation agent flagging continue conversation gets response."""
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
# Fake a test agent and prefer local intents
|
||||||
|
pipeline_store = pipeline_data.pipeline_store
|
||||||
|
pipeline_id = pipeline_store.async_get_preferred_item()
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
await assist_pipeline.pipeline.async_update_pipeline(
|
||||||
|
hass, pipeline, conversation_engine="test-agent"
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="Set a timer",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure prepare succeeds
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
):
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
response = intent.IntentResponse("en")
|
||||||
|
response.async_set_speech("For how long?")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
response=response,
|
||||||
|
conversation_id=mock_chat_session.conversation_id,
|
||||||
|
continue_conversation=True,
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
mock_async_converse.assert_called()
|
||||||
|
|
||||||
|
results = [
|
||||||
|
event.data
|
||||||
|
for event in events
|
||||||
|
if event.type
|
||||||
|
in (
|
||||||
|
assist_pipeline.PipelineEventType.INTENT_START,
|
||||||
|
assist_pipeline.PipelineEventType.INTENT_END,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert results[1]["intent_output"]["continue_conversation"] is True
|
||||||
|
|
||||||
|
# Change conversation agent to default one and register sentence trigger that should not be called
|
||||||
|
await assist_pipeline.pipeline.async_update_pipeline(
|
||||||
|
hass, pipeline, conversation_engine=None
|
||||||
|
)
|
||||||
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
assert await async_setup_component(
|
||||||
|
hass,
|
||||||
|
"automation",
|
||||||
|
{
|
||||||
|
"automation": {
|
||||||
|
"trigger": {
|
||||||
|
"platform": "conversation",
|
||||||
|
"command": ["Hello"],
|
||||||
|
},
|
||||||
|
"action": {
|
||||||
|
"set_conversation_response": "test trigger response",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Because we did continue conversation, it should respond to the test agent again.
|
||||||
|
events.clear()
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="Hello",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Ensure prepare succeeds
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_get_agent_info",
|
||||||
|
return_value=conversation.AgentInfo(id="test-agent", name="Test Agent"),
|
||||||
|
) as mock_prepare:
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
# It requested test agent even if that was not default agent.
|
||||||
|
assert mock_prepare.mock_calls[0][1][1] == "test-agent"
|
||||||
|
|
||||||
|
response = intent.IntentResponse("en")
|
||||||
|
response.async_set_speech("Timer set for 20 minutes")
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
response=response,
|
||||||
|
conversation_id=mock_chat_session.conversation_id,
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
mock_async_converse.assert_called()
|
||||||
|
|
||||||
|
# Snapshot will show it was still handled by the test agent and not default agent
|
||||||
|
results = [
|
||||||
|
event.data
|
||||||
|
for event in events
|
||||||
|
if event.type
|
||||||
|
in (
|
||||||
|
assist_pipeline.PipelineEventType.INTENT_START,
|
||||||
|
assist_pipeline.PipelineEventType.INTENT_END,
|
||||||
|
)
|
||||||
|
]
|
||||||
|
assert results[0]["engine"] == "test-agent"
|
||||||
|
assert results[1]["intent_output"]["continue_conversation"] is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_stt_language_used_instead_of_conversation_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the STT language is used first when the conversation language is '*' (all languages)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": MATCH_ALL,
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": "test",
|
||||||
|
"stt_language": "en-US",
|
||||||
|
"tts_engine": "test",
|
||||||
|
"tts_language": "en-US",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test input",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
intent.IntentResponse(pipeline.language)
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Check intent start event
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||||
|
intent_start = event
|
||||||
|
break
|
||||||
|
|
||||||
|
assert intent_start is not None
|
||||||
|
|
||||||
|
# STT language (en-US) should be used instead of '*'
|
||||||
|
assert intent_start.data.get("language") == pipeline.stt_language
|
||||||
|
|
||||||
|
# Check input to async_converse
|
||||||
|
mock_async_converse.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
|
== pipeline.stt_language
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_tts_language_used_instead_of_conversation_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the TTS language is used after STT when the conversation language is '*' (all languages)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": MATCH_ALL,
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": "en-us",
|
||||||
|
"tts_voice": "Arnold Schwarzenegger",
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test input",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
intent.IntentResponse(pipeline.language)
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Check intent start event
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||||
|
intent_start = event
|
||||||
|
break
|
||||||
|
|
||||||
|
assert intent_start is not None
|
||||||
|
|
||||||
|
# STT language (en-US) should be used instead of '*'
|
||||||
|
assert intent_start.data.get("language") == pipeline.tts_language
|
||||||
|
|
||||||
|
# Check input to async_converse
|
||||||
|
mock_async_converse.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
|
== pipeline.tts_language
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_pipeline_language_used_instead_of_conversation_language(
|
||||||
|
hass: HomeAssistant,
|
||||||
|
hass_ws_client: WebSocketGenerator,
|
||||||
|
init_components,
|
||||||
|
mock_chat_session: chat_session.ChatSession,
|
||||||
|
snapshot: SnapshotAssertion,
|
||||||
|
) -> None:
|
||||||
|
"""Test that the pipeline language is used last when the conversation language is '*' (all languages)."""
|
||||||
|
client = await hass_ws_client(hass)
|
||||||
|
|
||||||
|
events: list[assist_pipeline.PipelineEvent] = []
|
||||||
|
|
||||||
|
await client.send_json_auto_id(
|
||||||
|
{
|
||||||
|
"type": "assist_pipeline/pipeline/create",
|
||||||
|
"conversation_engine": "homeassistant",
|
||||||
|
"conversation_language": MATCH_ALL,
|
||||||
|
"language": "en",
|
||||||
|
"name": "test_name",
|
||||||
|
"stt_engine": None,
|
||||||
|
"stt_language": None,
|
||||||
|
"tts_engine": None,
|
||||||
|
"tts_language": None,
|
||||||
|
"tts_voice": None,
|
||||||
|
"wake_word_entity": None,
|
||||||
|
"wake_word_id": None,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
msg = await client.receive_json()
|
||||||
|
assert msg["success"]
|
||||||
|
pipeline_id = msg["result"]["id"]
|
||||||
|
pipeline = assist_pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
|
intent_input="test input",
|
||||||
|
session=mock_chat_session,
|
||||||
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
|
hass,
|
||||||
|
context=Context(),
|
||||||
|
pipeline=pipeline,
|
||||||
|
start_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
end_stage=assist_pipeline.PipelineStage.INTENT,
|
||||||
|
event_callback=events.append,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
await pipeline_input.validate()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.assist_pipeline.pipeline.conversation.async_converse",
|
||||||
|
return_value=conversation.ConversationResult(
|
||||||
|
intent.IntentResponse(pipeline.language)
|
||||||
|
),
|
||||||
|
) as mock_async_converse:
|
||||||
|
await pipeline_input.execute()
|
||||||
|
|
||||||
|
# Check intent start event
|
||||||
|
assert process_events(events) == snapshot
|
||||||
|
intent_start: assist_pipeline.PipelineEvent | None = None
|
||||||
|
for event in events:
|
||||||
|
if event.type == assist_pipeline.PipelineEventType.INTENT_START:
|
||||||
|
intent_start = event
|
||||||
|
break
|
||||||
|
|
||||||
|
assert intent_start is not None
|
||||||
|
|
||||||
|
# STT language (en-US) should be used instead of '*'
|
||||||
|
assert intent_start.data.get("language") == pipeline.language
|
||||||
|
|
||||||
|
# Check input to async_converse
|
||||||
|
mock_async_converse.assert_called_once()
|
||||||
|
assert (
|
||||||
|
mock_async_converse.call_args_list[0].kwargs.get("language")
|
||||||
|
== pipeline.language
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user