mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Assist Pipeline to use ChatSession for conversation ID (#137143)
* Assist Pipeline to use ChatSession for conversation ID * Adjust to latest changes
This commit is contained in:
parent
8acab6c646
commit
05ca80f4ba
@ -9,6 +9,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.components import stt
|
||||
from homeassistant.core import Context, HomeAssistant
|
||||
from homeassistant.helpers import chat_session
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import (
|
||||
@ -114,24 +115,25 @@ async def async_pipeline_from_audio_stream(
|
||||
|
||||
Raises PipelineNotFound if no pipeline is found.
|
||||
"""
|
||||
pipeline_input = PipelineInput(
|
||||
conversation_id=conversation_id,
|
||||
device_id=device_id,
|
||||
stt_metadata=stt_metadata,
|
||||
stt_stream=stt_stream,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
conversation_extra_system_prompt=conversation_extra_system_prompt,
|
||||
run=PipelineRun(
|
||||
hass,
|
||||
context=context,
|
||||
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
event_callback=event_callback,
|
||||
tts_audio_output=tts_audio_output,
|
||||
wake_word_settings=wake_word_settings,
|
||||
audio_settings=audio_settings or AudioSettings(),
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
await pipeline_input.execute()
|
||||
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||
pipeline_input = PipelineInput(
|
||||
conversation_id=session.conversation_id,
|
||||
device_id=device_id,
|
||||
stt_metadata=stt_metadata,
|
||||
stt_stream=stt_stream,
|
||||
wake_word_phrase=wake_word_phrase,
|
||||
conversation_extra_system_prompt=conversation_extra_system_prompt,
|
||||
run=PipelineRun(
|
||||
hass,
|
||||
context=context,
|
||||
pipeline=async_get_pipeline(hass, pipeline_id=pipeline_id),
|
||||
start_stage=start_stage,
|
||||
end_stage=end_stage,
|
||||
event_callback=event_callback,
|
||||
tts_audio_output=tts_audio_output,
|
||||
wake_word_settings=wake_word_settings,
|
||||
audio_settings=audio_settings or AudioSettings(),
|
||||
),
|
||||
)
|
||||
await pipeline_input.validate()
|
||||
await pipeline_input.execute()
|
||||
|
@ -624,7 +624,7 @@ class PipelineRun:
|
||||
return
|
||||
pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)
|
||||
|
||||
def start(self, device_id: str | None) -> None:
|
||||
def start(self, conversation_id: str, device_id: str | None) -> None:
|
||||
"""Emit run start event."""
|
||||
self._device_id = device_id
|
||||
self._start_debug_recording_thread()
|
||||
@ -632,6 +632,7 @@ class PipelineRun:
|
||||
data = {
|
||||
"pipeline": self.pipeline.id,
|
||||
"language": self.language,
|
||||
"conversation_id": conversation_id,
|
||||
}
|
||||
if self.runner_data is not None:
|
||||
data["runner_data"] = self.runner_data
|
||||
@ -1015,7 +1016,7 @@ class PipelineRun:
|
||||
async def recognize_intent(
|
||||
self,
|
||||
intent_input: str,
|
||||
conversation_id: str | None,
|
||||
conversation_id: str,
|
||||
device_id: str | None,
|
||||
conversation_extra_system_prompt: str | None,
|
||||
) -> str:
|
||||
@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
|
||||
wav_writer.close()
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(kw_only=True)
|
||||
class PipelineInput:
|
||||
"""Input to a pipeline run."""
|
||||
|
||||
run: PipelineRun
|
||||
|
||||
conversation_id: str
|
||||
"""Identifier for the conversation."""
|
||||
|
||||
stt_metadata: stt.SpeechMetadata | None = None
|
||||
"""Metadata of stt input audio. Required when start_stage = stt."""
|
||||
|
||||
@ -1430,9 +1434,6 @@ class PipelineInput:
|
||||
tts_input: str | None = None
|
||||
"""Input for text-to-speech. Required when start_stage = tts."""
|
||||
|
||||
conversation_id: str | None = None
|
||||
"""Identifier for the conversation."""
|
||||
|
||||
conversation_extra_system_prompt: str | None = None
|
||||
"""Extra prompt information for the conversation agent."""
|
||||
|
||||
@ -1441,7 +1442,7 @@ class PipelineInput:
|
||||
|
||||
async def execute(self) -> None:
|
||||
"""Run pipeline."""
|
||||
self.run.start(device_id=self.device_id)
|
||||
self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
|
||||
current_stage: PipelineStage | None = self.run.start_stage
|
||||
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||
|
@ -14,7 +14,11 @@ import voluptuous as vol
|
||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers import config_validation as cv, entity_registry as er
|
||||
from homeassistant.helpers import (
|
||||
chat_session,
|
||||
config_validation as cv,
|
||||
entity_registry as er,
|
||||
)
|
||||
from homeassistant.util import language as language_util
|
||||
|
||||
from .const import (
|
||||
@ -145,7 +149,6 @@ async def websocket_run(
|
||||
|
||||
# Arguments to PipelineInput
|
||||
input_args: dict[str, Any] = {
|
||||
"conversation_id": msg.get("conversation_id"),
|
||||
"device_id": msg.get("device_id"),
|
||||
}
|
||||
|
||||
@ -233,38 +236,42 @@ async def websocket_run(
|
||||
audio_settings=audio_settings or AudioSettings(),
|
||||
)
|
||||
|
||||
pipeline_input = PipelineInput(**input_args)
|
||||
with chat_session.async_get_chat_session(
|
||||
hass, msg.get("conversation_id")
|
||||
) as session:
|
||||
input_args["conversation_id"] = session.conversation_id
|
||||
pipeline_input = PipelineInput(**input_args)
|
||||
|
||||
try:
|
||||
await pipeline_input.validate()
|
||||
except PipelineError as error:
|
||||
# Report more specific error when possible
|
||||
connection.send_error(msg["id"], error.code, error.message)
|
||||
return
|
||||
try:
|
||||
await pipeline_input.validate()
|
||||
except PipelineError as error:
|
||||
# Report more specific error when possible
|
||||
connection.send_error(msg["id"], error.code, error.message)
|
||||
return
|
||||
|
||||
# Confirm subscription
|
||||
connection.send_result(msg["id"])
|
||||
# Confirm subscription
|
||||
connection.send_result(msg["id"])
|
||||
|
||||
run_task = hass.async_create_task(pipeline_input.execute())
|
||||
run_task = hass.async_create_task(pipeline_input.execute())
|
||||
|
||||
# Cancel pipeline if user unsubscribes
|
||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||
# Cancel pipeline if user unsubscribes
|
||||
connection.subscriptions[msg["id"]] = run_task.cancel
|
||||
|
||||
try:
|
||||
# Task contains a timeout
|
||||
async with asyncio.timeout(timeout):
|
||||
await run_task
|
||||
except TimeoutError:
|
||||
pipeline_input.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": "timeout", "message": "Timeout running pipeline"},
|
||||
try:
|
||||
# Task contains a timeout
|
||||
async with asyncio.timeout(timeout):
|
||||
await run_task
|
||||
except TimeoutError:
|
||||
pipeline_input.run.process_event(
|
||||
PipelineEvent(
|
||||
PipelineEventType.ERROR,
|
||||
{"code": "timeout", "message": "Timeout running pipeline"},
|
||||
)
|
||||
)
|
||||
)
|
||||
finally:
|
||||
if unregister_handler is not None:
|
||||
# Unregister binary handler
|
||||
unregister_handler()
|
||||
finally:
|
||||
if unregister_handler is not None:
|
||||
# Unregister binary handler
|
||||
unregister_handler()
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -3,6 +3,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -32,7 +33,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -94,6 +95,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -123,7 +125,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -185,6 +187,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -214,7 +217,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -276,6 +279,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -329,7 +333,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -391,6 +395,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -427,6 +432,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -434,7 +440,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -478,6 +484,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -485,7 +492,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -529,6 +536,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
@ -536,7 +544,7 @@
|
||||
}),
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test input',
|
||||
@ -580,6 +588,7 @@
|
||||
list([
|
||||
dict({
|
||||
'data': dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
}),
|
||||
|
@ -1,6 +1,7 @@
|
||||
# serializer version: 1
|
||||
# name: test_audio_pipeline
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -31,7 +32,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline.3
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -84,6 +85,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -114,7 +116,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_debug.3
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -179,6 +181,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -209,7 +212,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_enhancements.3
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -262,6 +265,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -314,7 +318,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_no_timeout.5
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'test transcript',
|
||||
@ -367,6 +371,7 @@
|
||||
# ---
|
||||
# name: test_audio_pipeline_with_wake_word_timeout
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -399,6 +404,7 @@
|
||||
# ---
|
||||
# name: test_device_capture
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -425,6 +431,7 @@
|
||||
# ---
|
||||
# name: test_device_capture_override
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -473,6 +480,7 @@
|
||||
# ---
|
||||
# name: test_device_capture_queue_full
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -512,6 +520,7 @@
|
||||
# ---
|
||||
# name: test_intent_failed
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -522,7 +531,7 @@
|
||||
# ---
|
||||
# name: test_intent_failed.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
@ -535,6 +544,7 @@
|
||||
# ---
|
||||
# name: test_intent_timeout
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -545,7 +555,7 @@
|
||||
# ---
|
||||
# name: test_intent_timeout.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'Are the lights on?',
|
||||
@ -564,6 +574,7 @@
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -574,7 +585,7 @@
|
||||
# ---
|
||||
# name: test_pipeline_empty_tts_output.1
|
||||
dict({
|
||||
'conversation_id': None,
|
||||
'conversation_id': 'mock-ulid',
|
||||
'device_id': None,
|
||||
'engine': 'conversation.home_assistant',
|
||||
'intent_input': 'never mind',
|
||||
@ -611,6 +622,7 @@
|
||||
# ---
|
||||
# name: test_stt_cooldown_different_ids
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -621,6 +633,7 @@
|
||||
# ---
|
||||
# name: test_stt_cooldown_different_ids.1
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -631,6 +644,7 @@
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -641,6 +655,7 @@
|
||||
# ---
|
||||
# name: test_stt_cooldown_same_id.1
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -651,6 +666,7 @@
|
||||
# ---
|
||||
# name: test_stt_stream_failed
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -677,6 +693,7 @@
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg0]
|
||||
dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -723,6 +740,7 @@
|
||||
# ---
|
||||
# name: test_text_only_pipeline[extra_msg1]
|
||||
dict({
|
||||
'conversation_id': 'mock-conversation-id',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -775,6 +793,7 @@
|
||||
# ---
|
||||
# name: test_tts_failed
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -796,6 +815,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -806,6 +826,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_entities.1
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -857,6 +878,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_ids
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -867,6 +889,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_different_ids.1
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -921,6 +944,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
@ -931,6 +955,7 @@
|
||||
# ---
|
||||
# name: test_wake_word_cooldown_same_id.1
|
||||
dict({
|
||||
'conversation_id': 'mock-ulid',
|
||||
'language': 'en',
|
||||
'pipeline': <ANY>,
|
||||
'runner_data': dict({
|
||||
|
@ -1,11 +1,12 @@
|
||||
"""Test Voice Assistant init."""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
import itertools as it
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
import wave
|
||||
|
||||
import hass_nabucasa
|
||||
@ -41,6 +42,14 @@ from .conftest import (
|
||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid of chat sessions."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||
"""Process events to remove dynamic values."""
|
||||
processed = []
|
||||
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
|
||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
conversation_id=None,
|
||||
conversation_id="mock-conversation-id",
|
||||
device_id=None,
|
||||
stt_metadata=stt.SpeechMetadata(
|
||||
language="",
|
||||
@ -771,7 +780,7 @@ async def test_tts_audio_output(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id=None,
|
||||
conversation_id="mock-conversation-id",
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id=None,
|
||||
conversation_id="mock-conversation-id",
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
tts_input="This is a test.",
|
||||
conversation_id=None,
|
||||
conversation_id="mock-conversation-id",
|
||||
device_id=None,
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test trigger sentence",
|
||||
conversation_id="mock-conversation-id",
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="I'd like to order a stout please",
|
||||
conversation_id="mock-conversation-id",
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
||||
|
||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||
intent_input="test input",
|
||||
conversation_id="mock-conversation-id",
|
||||
run=assist_pipeline.pipeline.PipelineRun(
|
||||
hass,
|
||||
context=Context(),
|
||||
|
@ -2,8 +2,9 @@
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from unittest.mock import ANY, patch
|
||||
from unittest.mock import ANY, Mock, patch
|
||||
|
||||
import pytest
|
||||
from syrupy.assertion import SnapshotAssertion
|
||||
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
|
||||
from tests.typing import WebSocketGenerator
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_ulid() -> Generator[Mock]:
|
||||
"""Mock the ulid of chat sessions."""
|
||||
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||
mock_ulid_now.return_value = "mock-ulid"
|
||||
yield mock_ulid_now
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"extra_msg",
|
||||
[
|
||||
|
Loading…
x
Reference in New Issue
Block a user