diff --git a/src/components/ha-assist-chat.ts b/src/components/ha-assist-chat.ts
index be1629326c..ce069a94ee 100644
--- a/src/components/ha-assist-chat.ts
+++ b/src/components/ha-assist-chat.ts
@@ -5,8 +5,11 @@ import { customElement, property, query, state } from "lit/decorators";
 import { classMap } from "lit/directives/class-map";
 import type { HomeAssistant } from "../types";
 import {
+  type PipelineRunEvent,
   runAssistPipeline,
   type AssistPipeline,
+  type ConversationChatLogAssistantDelta,
+  type ConversationChatLogToolResultDelta,
 } from "../data/assist_pipeline";
 import { supportsFeature } from "../common/entity/supports-feature";
 import { ConversationEntityFeature } from "../data/conversation";
@@ -90,7 +93,7 @@ export class HaAssistChat extends LitElement {
     super.disconnectedCallback();
     this._audioRecorder?.close();
     this._audioRecorder = undefined;
-    this._audio?.pause();
+    this._unloadAudio();
     this._conversation = [];
     this._conversationId = null;
   }
@@ -109,25 +112,24 @@ export class HaAssistChat extends LitElement {
     const supportsSTT = this.pipeline?.stt_engine && !this.disableSpeech;
 
     return html`
-      ${controlHA
-        ? nothing
-        : html`
-            <ha-alert>
-              ${this.hass.localize(
-                "ui.dialogs.voice_command.conversation_no_control"
-              )}
-            </ha-alert>
-          `}
-      <div class="messages">
-        <div class="messages-container">
-          ${this._conversation!.map(
-            // New lines matter for messages
-            // prettier-ignore
-            (message) => html`
+      <div class="messages">
+        ${controlHA
+          ? nothing
+          : html`
+              <ha-alert>
+                ${this.hass.localize(
+                  "ui.dialogs.voice_command.conversation_no_control"
+                )}
+              </ha-alert>
+            `}
+        <div class="spacer"></div>
+        ${this._conversation!.map(
+          // New lines matter for messages
+          // prettier-ignore
+          (message) => html`
             <p class="message ${message.who} ${message.error ? "error" : ""}">${message.text}</p>`
-          )}
-        </div>
+        )}
       </div>
@@ -293,27 +295,36 @@ export class HaAssistChat extends LitElement {
     await this._audioRecorder.start();
 
     this._addMessage(userMessage);
-    this.requestUpdate("_audioRecorder");
 
-    let continueConversation = false;
-    let hassMessage = {
-      who: "hass",
-      text: "…",
-      error: false,
-    };
-    let currentDeltaRole = "";
-    // To make sure the answer is placed at the right user text, we add it before we process it
+    const hassMessageProcesser = this._createAddHassMessageProcessor();
+
     try {
       const unsub = await runAssistPipeline(
         this.hass,
-        (event) => {
+        (event: PipelineRunEvent) => {
           if (event.type === "run-start") {
             this._stt_binary_handler_id =
               event.data.runner_data.stt_binary_handler_id;
+            this._audio = new Audio(event.data.tts_output!.url);
+            this._audio.play();
+            this._audio.addEventListener("ended", () => {
+              this._unloadAudio();
+              if (hassMessageProcesser.continueConversation) {
+                this._startListening();
+              }
+            });
+            this._audio.addEventListener("pause", this._unloadAudio);
+            this._audio.addEventListener("canplaythrough", () =>
+              this._audio?.play()
+            );
+            this._audio.addEventListener("error", () => {
+              this._unloadAudio();
+              showAlertDialog(this, { title: "Error playing audio." });
+            });
           }
           // When we start STT stage, the WS has a binary handler
-          if (event.type === "stt-start" && this._audioBuffer) {
+          else if (event.type === "stt-start" && this._audioBuffer) {
             // Send the buffer over the WS to the STT engine.
             for (const buffer of this._audioBuffer) {
               this._sendAudioChunk(buffer);
@@ -322,91 +333,26 @@ export class HaAssistChat extends LitElement {
           }
           // Stop recording if the server is done with STT stage
-          if (event.type === "stt-end") {
+          else if (event.type === "stt-end") {
            this._stt_binary_handler_id = undefined;
             this._stopListening();
             userMessage.text = event.data.stt_output.text;
             this.requestUpdate("_conversation");
-            // To make sure the answer is placed at the right user text, we add it before we process it
-            this._addMessage(hassMessage);
-          }
-
-          if (event.type === "intent-progress") {
-            const delta = event.data.chat_log_delta;
-
-            // new message
-            if (delta.role) {
-              // If currentDeltaRole exists, it means we're receiving our
-              // second or later message. Let's add it to the chat.
-              if (currentDeltaRole && delta.role && hassMessage.text !== "…") {
-                // Remove progress indicator of previous message
-                hassMessage.text = hassMessage.text.substring(
-                  0,
-                  hassMessage.text.length - 1
-                );
-
-                hassMessage = {
-                  who: "hass",
-                  text: "…",
-                  error: false,
-                };
-                this._addMessage(hassMessage);
-              }
-              currentDeltaRole = delta.role;
-            }
-
-            if (
-              currentDeltaRole === "assistant" &&
-              "content" in delta &&
-              delta.content
-            ) {
-              hassMessage.text =
-                hassMessage.text.substring(0, hassMessage.text.length - 1) +
-                delta.content +
-                "…";
-              this.requestUpdate("_conversation");
-            }
-          }
-
-          if (event.type === "intent-end") {
-            this._conversationId = event.data.intent_output.conversation_id;
-            continueConversation =
-              event.data.intent_output.continue_conversation;
-            const plain = event.data.intent_output.response.speech?.plain;
-            if (plain) {
-              hassMessage.text = plain.speech;
-            }
-            this.requestUpdate("_conversation");
-          }
-
-          if (event.type === "tts-end") {
-            const url = event.data.tts_output.url;
-            this._audio = new Audio(url);
-            this._audio.play();
-            this._audio.addEventListener("ended", () => {
-              this._unloadAudio();
-              if (continueConversation) {
-                this._startListening();
-              }
-            });
-            this._audio.addEventListener("pause", this._unloadAudio);
-            this._audio.addEventListener("canplaythrough", this._playAudio);
-            this._audio.addEventListener("error", this._audioError);
-          }
-
-          if (event.type === "run-end") {
+            // Add the response message placeholder to the chat when we know the STT is done
+            hassMessageProcesser.addMessage();
+          } else if (event.type.startsWith("intent-")) {
+            hassMessageProcesser.processEvent(event);
+          } else if (event.type === "run-end") {
             this._stt_binary_handler_id = undefined;
             unsub();
-          }
-
-          if (event.type === "error") {
+          } else if (event.type === "error") {
+            this._unloadAudio();
             this._stt_binary_handler_id = undefined;
             if (userMessage.text === "…") {
               userMessage.text = event.data.message;
               userMessage.error = true;
             } else {
-              hassMessage.text = event.data.message;
-              hassMessage.error = true;
+              hassMessageProcesser.setError(event.data.message);
             }
             this._stopListening();
             this.requestUpdate("_conversation");
@@ -464,90 +410,33 @@ export class HaAssistChat extends LitElement {
     this.hass.connection.socket!.send(data);
   }
 
-  private _playAudio = () => {
-    this._audio?.play();
-  };
-
-  private _audioError = () => {
-    showAlertDialog(this, { title: "Error playing audio." });
-    this._audio?.removeAttribute("src");
-  };
-
   private _unloadAudio = () => {
-    this._audio?.removeAttribute("src");
+    if (!this._audio) {
+      return;
+    }
+    this._audio.pause();
+    this._audio.removeAttribute("src");
     this._audio = undefined;
   };
 
   private async _processText(text: string) {
+    this._unloadAudio();
     this._processing = true;
-    this._audio?.pause();
     this._addMessage({ who: "user", text });
-    let hassMessage = {
-      who: "hass",
-      text: "…",
-      error: false,
-    };
-    let currentDeltaRole = "";
-    // To make sure the answer is placed at the right user text, we add it before we process it
-    this._addMessage(hassMessage);
+    const hassMessageProcesser = this._createAddHassMessageProcessor();
+    hassMessageProcesser.addMessage();
     try {
       const unsub = await runAssistPipeline(
         this.hass,
         (event) => {
-          if (event.type === "intent-progress") {
-            const delta = event.data.chat_log_delta;
-
-            // new message and previous message has content
-            if (delta.role) {
-              // If currentDeltaRole exists, it means we're receiving our
-              // second or later message. Let's add it to the chat.
-              if (
-                currentDeltaRole &&
-                delta.role === "assistant" &&
-                hassMessage.text !== "…"
-              ) {
-                // Remove progress indicator of previous message
-                hassMessage.text = hassMessage.text.substring(
-                  0,
-                  hassMessage.text.length - 1
-                );
-
-                hassMessage = {
-                  who: "hass",
-                  text: "…",
-                  error: false,
-                };
-                this._addMessage(hassMessage);
-              }
-              currentDeltaRole = delta.role;
-            }
-
-            if (
-              currentDeltaRole === "assistant" &&
-              "content" in delta &&
-              delta.content
-            ) {
-              hassMessage.text =
-                hassMessage.text.substring(0, hassMessage.text.length - 1) +
-                delta.content +
-                "…";
-              this.requestUpdate("_conversation");
-            }
+          if (event.type.startsWith("intent-")) {
+            hassMessageProcesser.processEvent(event);
           }
-
           if (event.type === "intent-end") {
-            this._conversationId = event.data.intent_output.conversation_id;
-            const plain = event.data.intent_output.response.speech?.plain;
-            if (plain) {
-              hassMessage.text = plain.speech;
-            }
-            this.requestUpdate("_conversation");
             unsub();
           }
           if (event.type === "error") {
-            hassMessage.text = event.data.message;
-            hassMessage.error = true;
-            this.requestUpdate("_conversation");
+            hassMessageProcesser.setError(event.data.message);
             unsub();
           }
         },
@@ -560,20 +449,126 @@
         }
       );
     } catch {
-      hassMessage.text = this.hass.localize("ui.dialogs.voice_command.error");
-      hassMessage.error = true;
-      this.requestUpdate("_conversation");
+      hassMessageProcesser.setError(
+        this.hass.localize("ui.dialogs.voice_command.error")
+      );
     } finally {
       this._processing = false;
     }
   }
 
+  private _createAddHassMessageProcessor() {
+    let currentDeltaRole = "";
+
+    const progressToNextMessage = () => {
+      if (progress.hassMessage.text === "…") {
+        return;
+      }
+      progress.hassMessage.text = progress.hassMessage.text.substring(
+        0,
+        progress.hassMessage.text.length - 1
+      );
+
+      progress.hassMessage = {
+        who: "hass",
+        text: "…",
+        error: false,
+      };
+      this._addMessage(progress.hassMessage);
+    };
+
+    const isAssistantDelta = (
+      _delta: any
+    ): _delta is Partial<ConversationChatLogAssistantDelta> =>
+      currentDeltaRole === "assistant";
+
+    const isToolResult = (
+      _delta: any
+    ): _delta is ConversationChatLogToolResultDelta =>
+      currentDeltaRole === "tool_result";
+
+    const tools: Record<
+      string,
+      ConversationChatLogAssistantDelta["tool_calls"][0]
+    > = {};
+
+    const progress = {
+      continueConversation: false,
+      hassMessage: {
+        who: "hass",
+        text: "…",
+        error: false,
+      },
+      addMessage: () => {
+        this._addMessage(progress.hassMessage);
+      },
+      setError: (error: string) => {
+        progressToNextMessage();
+        progress.hassMessage.text = error;
+        progress.hassMessage.error = true;
+        this.requestUpdate("_conversation");
+      },
+      processEvent: (event: PipelineRunEvent) => {
+        if (event.type === "intent-progress") {
+          const delta = event.data.chat_log_delta;
+
+          // new message
+          if (delta.role) {
+            progressToNextMessage();
+            currentDeltaRole = delta.role;
+          }
+
+          if (isAssistantDelta(delta)) {
+            if (delta.content) {
+              progress.hassMessage.text =
+                progress.hassMessage.text.substring(
+                  0,
+                  progress.hassMessage.text.length - 1
+                ) +
+                delta.content +
+                "…";
+              this.requestUpdate("_conversation");
+            }
+            if (delta.tool_calls) {
+              for (const toolCall of delta.tool_calls) {
+                tools[toolCall.id] = toolCall;
+              }
+            }
+          } else if (isToolResult(delta)) {
+            if (tools[delta.tool_call_id]) {
+              delete tools[delta.tool_call_id];
+            }
+          }
+        } else if (event.type === "intent-end") {
+          this._conversationId = event.data.intent_output.conversation_id;
+          progress.continueConversation =
+            event.data.intent_output.continue_conversation;
+          const response =
+            event.data.intent_output.response.speech?.plain.speech;
+          if (!response) {
+            return;
+          }
+          if (event.data.intent_output.response.response_type === "error") {
+            progress.setError(response);
+          } else {
+            progress.hassMessage.text = response;
+            this.requestUpdate("_conversation");
+          }
+        }
+      },
+    };
+    return progress;
+  }
+
   static styles = css`
     :host {
       flex: 1;
       display: flex;
       flex-direction: column;
     }
+    ha-alert {
+      margin-bottom: 8px;
+    }
     ha-textfield {
       display: block;
     }
@@ -581,17 +576,14 @@
       flex: 1;
       display: block;
       box-sizing: border-box;
-      position: relative;
-    }
-    .messages-container {
-      position: absolute;
-      bottom: 0px;
-      right: 0px;
-      left: 0px;
-      padding: 0px 10px 16px;
-      box-sizing: border-box;
       overflow-y: auto;
       max-height: 100%;
+      display: flex;
+      flex-direction: column;
+      padding: 0 12px 16px;
+    }
+    .spacer {
+      flex: 1;
     }
     .message {
       white-space: pre-line;
@@ -601,6 +593,9 @@
       padding: 8px;
       border-radius: 15px;
     }
+    .message:last-child {
+      margin-bottom: 0;
+    }
 
     @media all and (max-width: 450px), all and (max-height: 500px) {
       .message {
@@ -619,7 +614,7 @@
       margin-left: 24px;
       margin-inline-start: 24px;
       margin-inline-end: initial;
-      float: var(--float-end);
+      align-self: flex-end;
       text-align: right;
       border-bottom-right-radius: 0px;
       background-color: var(--chat-background-color-user, var(--primary-color));
@@ -631,7 +626,7 @@
       margin-right: 24px;
       margin-inline-end: 24px;
       margin-inline-start: initial;
-      float: var(--float-start);
+      align-self: flex-start;
       border-bottom-left-radius: 0px;
       background-color: var(
         --chat-background-color-hass,
diff --git a/src/data/assist_pipeline.ts b/src/data/assist_pipeline.ts
index 72da3c5e9a..8c6d10949a 100644
--- a/src/data/assist_pipeline.ts
+++ b/src/data/assist_pipeline.ts
@@ -1,6 +1,5 @@
 import type { HomeAssistant } from "../types";
 import type { ConversationResult } from "./conversation";
-import type { ResolvedMediaSource } from "./media_source";
 import type { SpeechMetadata } from "./stt";
 
 export interface AssistPipeline {
@@ -53,10 +52,16 @@ interface PipelineRunStartEvent extends PipelineEventBase {
   data: {
     pipeline: string;
     language: string;
+    conversation_id: string;
     runner_data: {
       stt_binary_handler_id: number | null;
       timeout: number;
     };
+    tts_output?: {
+      token: string;
+      url: string;
+      mime_type: string;
+    };
   };
 }
 interface PipelineRunEndEvent extends PipelineEventBase {
@@ -109,7 +114,7 @@ interface PipelineIntentStartEvent extends PipelineEventBase {
   };
 }
 
-interface ConversationChatLogAssistantDelta {
+export interface ConversationChatLogAssistantDelta {
   role: "assistant";
   content: string;
   tool_calls: {
@@ -119,7 +124,7 @@ interface ConversationChatLogAssistantDelta {
   }[];
 }
 
-interface ConversationChatLogToolResultDelta {
+export interface ConversationChatLogToolResultDelta {
   role: "tool_result";
   agent_id: string;
   tool_call_id: string;
@@ -156,7 +161,12 @@ interface PipelineTTSStartEvent extends PipelineEventBase {
 interface PipelineTTSEndEvent extends PipelineEventBase {
   type: "tts-end";
   data: {
-    tts_output: ResolvedMediaSource;
+    tts_output: {
+      media_id: string;
+      token: string;
+      url: string;
+      mime_type: string;
+    };
   };
 }