mirror of
https://github.com/home-assistant/core.git
synced 2025-07-26 22:57:17 +00:00
Assist Pipeline to use ChatSession for conversation ID (#137143)
* Assist Pipeline to use ChatSession for conversation ID * Adjust to latest changes
This commit is contained in:
parent
8acab6c646
commit
05ca80f4ba
@ -9,6 +9,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.components import stt
|
from homeassistant.components import stt
|
||||||
from homeassistant.core import Context, HomeAssistant
|
from homeassistant.core import Context, HomeAssistant
|
||||||
|
from homeassistant.helpers import chat_session
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -114,8 +115,9 @@ async def async_pipeline_from_audio_stream(
|
|||||||
|
|
||||||
Raises PipelineNotFound if no pipeline is found.
|
Raises PipelineNotFound if no pipeline is found.
|
||||||
"""
|
"""
|
||||||
|
with chat_session.async_get_chat_session(hass, conversation_id) as session:
|
||||||
pipeline_input = PipelineInput(
|
pipeline_input = PipelineInput(
|
||||||
conversation_id=conversation_id,
|
conversation_id=session.conversation_id,
|
||||||
device_id=device_id,
|
device_id=device_id,
|
||||||
stt_metadata=stt_metadata,
|
stt_metadata=stt_metadata,
|
||||||
stt_stream=stt_stream,
|
stt_stream=stt_stream,
|
||||||
|
@ -624,7 +624,7 @@ class PipelineRun:
|
|||||||
return
|
return
|
||||||
pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)
|
pipeline_data.pipeline_debug[self.pipeline.id][self.id].events.append(event)
|
||||||
|
|
||||||
def start(self, device_id: str | None) -> None:
|
def start(self, conversation_id: str, device_id: str | None) -> None:
|
||||||
"""Emit run start event."""
|
"""Emit run start event."""
|
||||||
self._device_id = device_id
|
self._device_id = device_id
|
||||||
self._start_debug_recording_thread()
|
self._start_debug_recording_thread()
|
||||||
@ -632,6 +632,7 @@ class PipelineRun:
|
|||||||
data = {
|
data = {
|
||||||
"pipeline": self.pipeline.id,
|
"pipeline": self.pipeline.id,
|
||||||
"language": self.language,
|
"language": self.language,
|
||||||
|
"conversation_id": conversation_id,
|
||||||
}
|
}
|
||||||
if self.runner_data is not None:
|
if self.runner_data is not None:
|
||||||
data["runner_data"] = self.runner_data
|
data["runner_data"] = self.runner_data
|
||||||
@ -1015,7 +1016,7 @@ class PipelineRun:
|
|||||||
async def recognize_intent(
|
async def recognize_intent(
|
||||||
self,
|
self,
|
||||||
intent_input: str,
|
intent_input: str,
|
||||||
conversation_id: str | None,
|
conversation_id: str,
|
||||||
device_id: str | None,
|
device_id: str | None,
|
||||||
conversation_extra_system_prompt: str | None,
|
conversation_extra_system_prompt: str | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
@ -1409,12 +1410,15 @@ def _pipeline_debug_recording_thread_proc(
|
|||||||
wav_writer.close()
|
wav_writer.close()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(kw_only=True)
|
||||||
class PipelineInput:
|
class PipelineInput:
|
||||||
"""Input to a pipeline run."""
|
"""Input to a pipeline run."""
|
||||||
|
|
||||||
run: PipelineRun
|
run: PipelineRun
|
||||||
|
|
||||||
|
conversation_id: str
|
||||||
|
"""Identifier for the conversation."""
|
||||||
|
|
||||||
stt_metadata: stt.SpeechMetadata | None = None
|
stt_metadata: stt.SpeechMetadata | None = None
|
||||||
"""Metadata of stt input audio. Required when start_stage = stt."""
|
"""Metadata of stt input audio. Required when start_stage = stt."""
|
||||||
|
|
||||||
@ -1430,9 +1434,6 @@ class PipelineInput:
|
|||||||
tts_input: str | None = None
|
tts_input: str | None = None
|
||||||
"""Input for text-to-speech. Required when start_stage = tts."""
|
"""Input for text-to-speech. Required when start_stage = tts."""
|
||||||
|
|
||||||
conversation_id: str | None = None
|
|
||||||
"""Identifier for the conversation."""
|
|
||||||
|
|
||||||
conversation_extra_system_prompt: str | None = None
|
conversation_extra_system_prompt: str | None = None
|
||||||
"""Extra prompt information for the conversation agent."""
|
"""Extra prompt information for the conversation agent."""
|
||||||
|
|
||||||
@ -1441,7 +1442,7 @@ class PipelineInput:
|
|||||||
|
|
||||||
async def execute(self) -> None:
|
async def execute(self) -> None:
|
||||||
"""Run pipeline."""
|
"""Run pipeline."""
|
||||||
self.run.start(device_id=self.device_id)
|
self.run.start(conversation_id=self.conversation_id, device_id=self.device_id)
|
||||||
current_stage: PipelineStage | None = self.run.start_stage
|
current_stage: PipelineStage | None = self.run.start_stage
|
||||||
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
stt_audio_buffer: list[EnhancedAudioChunk] = []
|
||||||
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
stt_processed_stream: AsyncIterable[EnhancedAudioChunk] | None = None
|
||||||
|
@ -14,7 +14,11 @@ import voluptuous as vol
|
|||||||
from homeassistant.components import conversation, stt, tts, websocket_api
|
from homeassistant.components import conversation, stt, tts, websocket_api
|
||||||
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
from homeassistant.const import ATTR_DEVICE_ID, ATTR_SECONDS, MATCH_ALL
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.helpers import config_validation as cv, entity_registry as er
|
from homeassistant.helpers import (
|
||||||
|
chat_session,
|
||||||
|
config_validation as cv,
|
||||||
|
entity_registry as er,
|
||||||
|
)
|
||||||
from homeassistant.util import language as language_util
|
from homeassistant.util import language as language_util
|
||||||
|
|
||||||
from .const import (
|
from .const import (
|
||||||
@ -145,7 +149,6 @@ async def websocket_run(
|
|||||||
|
|
||||||
# Arguments to PipelineInput
|
# Arguments to PipelineInput
|
||||||
input_args: dict[str, Any] = {
|
input_args: dict[str, Any] = {
|
||||||
"conversation_id": msg.get("conversation_id"),
|
|
||||||
"device_id": msg.get("device_id"),
|
"device_id": msg.get("device_id"),
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -233,6 +236,10 @@ async def websocket_run(
|
|||||||
audio_settings=audio_settings or AudioSettings(),
|
audio_settings=audio_settings or AudioSettings(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
with chat_session.async_get_chat_session(
|
||||||
|
hass, msg.get("conversation_id")
|
||||||
|
) as session:
|
||||||
|
input_args["conversation_id"] = session.conversation_id
|
||||||
pipeline_input = PipelineInput(**input_args)
|
pipeline_input = PipelineInput(**input_args)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -32,7 +33,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -94,6 +95,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -123,7 +125,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -185,6 +187,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -214,7 +217,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -276,6 +279,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -329,7 +333,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -391,6 +395,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -427,6 +432,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -434,7 +440,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-conversation-id',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test input',
|
'intent_input': 'test input',
|
||||||
@ -478,6 +484,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -485,7 +492,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-conversation-id',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test input',
|
'intent_input': 'test input',
|
||||||
@ -529,6 +536,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
@ -536,7 +544,7 @@
|
|||||||
}),
|
}),
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-conversation-id',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test input',
|
'intent_input': 'test input',
|
||||||
@ -580,6 +588,7 @@
|
|||||||
list([
|
list([
|
||||||
dict({
|
dict({
|
||||||
'data': dict({
|
'data': dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
}),
|
}),
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# serializer version: 1
|
# serializer version: 1
|
||||||
# name: test_audio_pipeline
|
# name: test_audio_pipeline
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -31,7 +32,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline.3
|
# name: test_audio_pipeline.3
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -84,6 +85,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug
|
# name: test_audio_pipeline_debug
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -114,7 +116,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_debug.3
|
# name: test_audio_pipeline_debug.3
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -179,6 +181,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_enhancements
|
# name: test_audio_pipeline_with_enhancements
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -209,7 +212,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_enhancements.3
|
# name: test_audio_pipeline_with_enhancements.3
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -262,6 +265,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word_no_timeout
|
# name: test_audio_pipeline_with_wake_word_no_timeout
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -314,7 +318,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word_no_timeout.5
|
# name: test_audio_pipeline_with_wake_word_no_timeout.5
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'test transcript',
|
'intent_input': 'test transcript',
|
||||||
@ -367,6 +371,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_audio_pipeline_with_wake_word_timeout
|
# name: test_audio_pipeline_with_wake_word_timeout
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -399,6 +404,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_device_capture
|
# name: test_device_capture
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -425,6 +431,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_device_capture_override
|
# name: test_device_capture_override
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -473,6 +480,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_device_capture_queue_full
|
# name: test_device_capture_queue_full
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -512,6 +520,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_failed
|
# name: test_intent_failed
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -522,7 +531,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_failed.1
|
# name: test_intent_failed.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
@ -535,6 +544,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout
|
# name: test_intent_timeout
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -545,7 +555,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_intent_timeout.1
|
# name: test_intent_timeout.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'Are the lights on?',
|
'intent_input': 'Are the lights on?',
|
||||||
@ -564,6 +574,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_pipeline_empty_tts_output
|
# name: test_pipeline_empty_tts_output
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -574,7 +585,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_pipeline_empty_tts_output.1
|
# name: test_pipeline_empty_tts_output.1
|
||||||
dict({
|
dict({
|
||||||
'conversation_id': None,
|
'conversation_id': 'mock-ulid',
|
||||||
'device_id': None,
|
'device_id': None,
|
||||||
'engine': 'conversation.home_assistant',
|
'engine': 'conversation.home_assistant',
|
||||||
'intent_input': 'never mind',
|
'intent_input': 'never mind',
|
||||||
@ -611,6 +622,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_stt_cooldown_different_ids
|
# name: test_stt_cooldown_different_ids
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -621,6 +633,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_stt_cooldown_different_ids.1
|
# name: test_stt_cooldown_different_ids.1
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -631,6 +644,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_stt_cooldown_same_id
|
# name: test_stt_cooldown_same_id
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -641,6 +655,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_stt_cooldown_same_id.1
|
# name: test_stt_cooldown_same_id.1
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -651,6 +666,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_stt_stream_failed
|
# name: test_stt_stream_failed
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -677,6 +693,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline[extra_msg0]
|
# name: test_text_only_pipeline[extra_msg0]
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -723,6 +740,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_text_only_pipeline[extra_msg1]
|
# name: test_text_only_pipeline[extra_msg1]
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-conversation-id',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -775,6 +793,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_tts_failed
|
# name: test_tts_failed
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -796,6 +815,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_entities
|
# name: test_wake_word_cooldown_different_entities
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -806,6 +826,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_entities.1
|
# name: test_wake_word_cooldown_different_entities.1
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -857,6 +878,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_ids
|
# name: test_wake_word_cooldown_different_ids
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -867,6 +889,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_different_ids.1
|
# name: test_wake_word_cooldown_different_ids.1
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -921,6 +944,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_same_id
|
# name: test_wake_word_cooldown_same_id
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
@ -931,6 +955,7 @@
|
|||||||
# ---
|
# ---
|
||||||
# name: test_wake_word_cooldown_same_id.1
|
# name: test_wake_word_cooldown_same_id.1
|
||||||
dict({
|
dict({
|
||||||
|
'conversation_id': 'mock-ulid',
|
||||||
'language': 'en',
|
'language': 'en',
|
||||||
'pipeline': <ANY>,
|
'pipeline': <ANY>,
|
||||||
'runner_data': dict({
|
'runner_data': dict({
|
||||||
|
@ -1,11 +1,12 @@
|
|||||||
"""Test Voice Assistant init."""
|
"""Test Voice Assistant init."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from collections.abc import Generator
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
import itertools as it
|
import itertools as it
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import tempfile
|
import tempfile
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
import wave
|
import wave
|
||||||
|
|
||||||
import hass_nabucasa
|
import hass_nabucasa
|
||||||
@ -41,6 +42,14 @@ from .conftest import (
|
|||||||
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
from tests.typing import ClientSessionGenerator, WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_ulid() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid of chat sessions."""
|
||||||
|
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
def process_events(events: list[assist_pipeline.PipelineEvent]) -> list[dict]:
|
||||||
"""Process events to remove dynamic values."""
|
"""Process events to remove dynamic values."""
|
||||||
processed = []
|
processed = []
|
||||||
@ -684,7 +693,7 @@ async def test_wake_word_detection_aborted(
|
|||||||
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
pipeline = assist_pipeline.pipeline.async_get_pipeline(hass, pipeline_id)
|
||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
conversation_id=None,
|
conversation_id="mock-conversation-id",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
stt_metadata=stt.SpeechMetadata(
|
stt_metadata=stt.SpeechMetadata(
|
||||||
language="",
|
language="",
|
||||||
@ -771,7 +780,7 @@ async def test_tts_audio_output(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
tts_input="This is a test.",
|
tts_input="This is a test.",
|
||||||
conversation_id=None,
|
conversation_id="mock-conversation-id",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
@ -828,7 +837,7 @@ async def test_tts_wav_preferred_format(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
tts_input="This is a test.",
|
tts_input="This is a test.",
|
||||||
conversation_id=None,
|
conversation_id="mock-conversation-id",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
@ -896,7 +905,7 @@ async def test_tts_dict_preferred_format(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
tts_input="This is a test.",
|
tts_input="This is a test.",
|
||||||
conversation_id=None,
|
conversation_id="mock-conversation-id",
|
||||||
device_id=None,
|
device_id=None,
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
@ -982,6 +991,7 @@ async def test_sentence_trigger_overrides_conversation_agent(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
intent_input="test trigger sentence",
|
intent_input="test trigger sentence",
|
||||||
|
conversation_id="mock-conversation-id",
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
@ -1059,6 +1069,7 @@ async def test_prefer_local_intents(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
intent_input="I'd like to order a stout please",
|
intent_input="I'd like to order a stout please",
|
||||||
|
conversation_id="mock-conversation-id",
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
@ -1136,6 +1147,7 @@ async def test_stt_language_used_instead_of_conversation_language(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
intent_input="test input",
|
intent_input="test input",
|
||||||
|
conversation_id="mock-conversation-id",
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
@ -1210,6 +1222,7 @@ async def test_tts_language_used_instead_of_conversation_language(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
intent_input="test input",
|
intent_input="test input",
|
||||||
|
conversation_id="mock-conversation-id",
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
@ -1284,6 +1297,7 @@ async def test_pipeline_language_used_instead_of_conversation_language(
|
|||||||
|
|
||||||
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
pipeline_input = assist_pipeline.pipeline.PipelineInput(
|
||||||
intent_input="test input",
|
intent_input="test input",
|
||||||
|
conversation_id="mock-conversation-id",
|
||||||
run=assist_pipeline.pipeline.PipelineRun(
|
run=assist_pipeline.pipeline.PipelineRun(
|
||||||
hass,
|
hass,
|
||||||
context=Context(),
|
context=Context(),
|
||||||
|
@ -2,8 +2,9 @@
|
|||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import base64
|
import base64
|
||||||
|
from collections.abc import Generator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
from unittest.mock import ANY, patch
|
from unittest.mock import ANY, Mock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from syrupy.assertion import SnapshotAssertion
|
from syrupy.assertion import SnapshotAssertion
|
||||||
@ -35,6 +36,14 @@ from tests.common import MockConfigEntry
|
|||||||
from tests.typing import WebSocketGenerator
|
from tests.typing import WebSocketGenerator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_ulid() -> Generator[Mock]:
|
||||||
|
"""Mock the ulid of chat sessions."""
|
||||||
|
with patch("homeassistant.helpers.chat_session.ulid_now") as mock_ulid_now:
|
||||||
|
mock_ulid_now.return_value = "mock-ulid"
|
||||||
|
yield mock_ulid_now
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"extra_msg",
|
"extra_msg",
|
||||||
[
|
[
|
||||||
|
Loading…
x
Reference in New Issue
Block a user