Tweaks to Assist chat dialog (#25494)

This commit is contained in:
Paulus Schoutsen 2025-05-20 02:39:56 -04:00 committed by GitHub
parent a026c72230
commit 83df10ef29
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 195 additions and 190 deletions

View File

@ -5,8 +5,11 @@ import { customElement, property, query, state } from "lit/decorators";
import { classMap } from "lit/directives/class-map"; import { classMap } from "lit/directives/class-map";
import type { HomeAssistant } from "../types"; import type { HomeAssistant } from "../types";
import { import {
type PipelineRunEvent,
runAssistPipeline, runAssistPipeline,
type AssistPipeline, type AssistPipeline,
type ConversationChatLogAssistantDelta,
type ConversationChatLogToolResultDelta,
} from "../data/assist_pipeline"; } from "../data/assist_pipeline";
import { supportsFeature } from "../common/entity/supports-feature"; import { supportsFeature } from "../common/entity/supports-feature";
import { ConversationEntityFeature } from "../data/conversation"; import { ConversationEntityFeature } from "../data/conversation";
@ -90,7 +93,7 @@ export class HaAssistChat extends LitElement {
super.disconnectedCallback(); super.disconnectedCallback();
this._audioRecorder?.close(); this._audioRecorder?.close();
this._audioRecorder = undefined; this._audioRecorder = undefined;
this._audio?.pause(); this._unloadAudio();
this._conversation = []; this._conversation = [];
this._conversationId = null; this._conversationId = null;
} }
@ -109,6 +112,7 @@ export class HaAssistChat extends LitElement {
const supportsSTT = this.pipeline?.stt_engine && !this.disableSpeech; const supportsSTT = this.pipeline?.stt_engine && !this.disableSpeech;
return html` return html`
<div class="messages" id="scroll-container">
${controlHA ${controlHA
? nothing ? nothing
: html` : html`
@ -118,8 +122,7 @@ export class HaAssistChat extends LitElement {
)} )}
</ha-alert> </ha-alert>
`} `}
<div class="messages"> <div class="spacer"></div>
<div class="messages-container" id="scroll-container">
${this._conversation!.map( ${this._conversation!.map(
// New lines matter for messages // New lines matter for messages
// prettier-ignore // prettier-ignore
@ -128,7 +131,6 @@ export class HaAssistChat extends LitElement {
` `
)} )}
</div> </div>
</div>
<div class="input" slot="primaryAction"> <div class="input" slot="primaryAction">
<ha-textfield <ha-textfield
id="message-input" id="message-input"
@ -273,8 +275,8 @@ export class HaAssistChat extends LitElement {
} }
private async _startListening() { private async _startListening() {
this._unloadAudio();
this._processing = true; this._processing = true;
this._audio?.pause();
if (!this._audioRecorder) { if (!this._audioRecorder) {
this._audioRecorder = new AudioRecorder((audio) => { this._audioRecorder = new AudioRecorder((audio) => {
if (this._audioBuffer) { if (this._audioBuffer) {
@ -293,27 +295,36 @@ export class HaAssistChat extends LitElement {
await this._audioRecorder.start(); await this._audioRecorder.start();
this._addMessage(userMessage); this._addMessage(userMessage);
this.requestUpdate("_audioRecorder");
let continueConversation = false; const hassMessageProcesser = this._createAddHassMessageProcessor();
let hassMessage = {
who: "hass",
text: "…",
error: false,
};
let currentDeltaRole = "";
// To make sure the answer is placed at the right user text, we add it before we process it
try { try {
const unsub = await runAssistPipeline( const unsub = await runAssistPipeline(
this.hass, this.hass,
(event) => { (event: PipelineRunEvent) => {
if (event.type === "run-start") { if (event.type === "run-start") {
this._stt_binary_handler_id = this._stt_binary_handler_id =
event.data.runner_data.stt_binary_handler_id; event.data.runner_data.stt_binary_handler_id;
this._audio = new Audio(event.data.tts_output!.url);
this._audio.play();
this._audio.addEventListener("ended", () => {
this._unloadAudio();
if (hassMessageProcesser.continueConversation) {
this._startListening();
}
});
this._audio.addEventListener("pause", this._unloadAudio);
this._audio.addEventListener("canplaythrough", () =>
this._audio?.play()
);
this._audio.addEventListener("error", () => {
this._unloadAudio();
showAlertDialog(this, { title: "Error playing audio." });
});
} }
// When we start STT stage, the WS has a binary handler // When we start STT stage, the WS has a binary handler
if (event.type === "stt-start" && this._audioBuffer) { else if (event.type === "stt-start" && this._audioBuffer) {
// Send the buffer over the WS to the STT engine. // Send the buffer over the WS to the STT engine.
for (const buffer of this._audioBuffer) { for (const buffer of this._audioBuffer) {
this._sendAudioChunk(buffer); this._sendAudioChunk(buffer);
@ -322,91 +333,26 @@ export class HaAssistChat extends LitElement {
} }
// Stop recording if the server is done with STT stage // Stop recording if the server is done with STT stage
if (event.type === "stt-end") { else if (event.type === "stt-end") {
this._stt_binary_handler_id = undefined; this._stt_binary_handler_id = undefined;
this._stopListening(); this._stopListening();
userMessage.text = event.data.stt_output.text; userMessage.text = event.data.stt_output.text;
this.requestUpdate("_conversation"); this.requestUpdate("_conversation");
// To make sure the answer is placed at the right user text, we add it before we process it // Add the response message placeholder to the chat when we know the STT is done
this._addMessage(hassMessage); hassMessageProcesser.addMessage();
} } else if (event.type.startsWith("intent-")) {
hassMessageProcesser.processEvent(event);
if (event.type === "intent-progress") { } else if (event.type === "run-end") {
const delta = event.data.chat_log_delta;
// new message
if (delta.role) {
// If currentDeltaRole exists, it means we're receiving our
// second or later message. Let's add it to the chat.
if (currentDeltaRole && delta.role && hassMessage.text !== "…") {
// Remove progress indicator of previous message
hassMessage.text = hassMessage.text.substring(
0,
hassMessage.text.length - 1
);
hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(hassMessage);
}
currentDeltaRole = delta.role;
}
if (
currentDeltaRole === "assistant" &&
"content" in delta &&
delta.content
) {
hassMessage.text =
hassMessage.text.substring(0, hassMessage.text.length - 1) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
}
if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
continueConversation =
event.data.intent_output.continue_conversation;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
hassMessage.text = plain.speech;
}
this.requestUpdate("_conversation");
}
if (event.type === "tts-end") {
const url = event.data.tts_output.url;
this._audio = new Audio(url);
this._audio.play();
this._audio.addEventListener("ended", () => {
this._unloadAudio();
if (continueConversation) {
this._startListening();
}
});
this._audio.addEventListener("pause", this._unloadAudio);
this._audio.addEventListener("canplaythrough", this._playAudio);
this._audio.addEventListener("error", this._audioError);
}
if (event.type === "run-end") {
this._stt_binary_handler_id = undefined; this._stt_binary_handler_id = undefined;
unsub(); unsub();
} } else if (event.type === "error") {
this._unloadAudio();
if (event.type === "error") {
this._stt_binary_handler_id = undefined; this._stt_binary_handler_id = undefined;
if (userMessage.text === "…") { if (userMessage.text === "…") {
userMessage.text = event.data.message; userMessage.text = event.data.message;
userMessage.error = true; userMessage.error = true;
} else { } else {
hassMessage.text = event.data.message; hassMessageProcesser.setError(event.data.message);
hassMessage.error = true;
} }
this._stopListening(); this._stopListening();
this.requestUpdate("_conversation"); this.requestUpdate("_conversation");
@ -464,90 +410,33 @@ export class HaAssistChat extends LitElement {
this.hass.connection.socket!.send(data); this.hass.connection.socket!.send(data);
} }
private _playAudio = () => {
this._audio?.play();
};
private _audioError = () => {
showAlertDialog(this, { title: "Error playing audio." });
this._audio?.removeAttribute("src");
};
private _unloadAudio = () => { private _unloadAudio = () => {
this._audio?.removeAttribute("src"); if (!this._audio) {
return;
}
this._audio.pause();
this._audio.removeAttribute("src");
this._audio = undefined; this._audio = undefined;
}; };
private async _processText(text: string) { private async _processText(text: string) {
this._unloadAudio();
this._processing = true; this._processing = true;
this._audio?.pause();
this._addMessage({ who: "user", text }); this._addMessage({ who: "user", text });
let hassMessage = { const hassMessageProcesser = this._createAddHassMessageProcessor();
who: "hass", hassMessageProcesser.addMessage();
text: "…",
error: false,
};
let currentDeltaRole = "";
// To make sure the answer is placed at the right user text, we add it before we process it
this._addMessage(hassMessage);
try { try {
const unsub = await runAssistPipeline( const unsub = await runAssistPipeline(
this.hass, this.hass,
(event) => { (event) => {
if (event.type === "intent-progress") { if (event.type.startsWith("intent-")) {
const delta = event.data.chat_log_delta; hassMessageProcesser.processEvent(event);
// new message and previous message has content
if (delta.role) {
// If currentDeltaRole exists, it means we're receiving our
// second or later message. Let's add it to the chat.
if (
currentDeltaRole &&
delta.role === "assistant" &&
hassMessage.text !== "…"
) {
// Remove progress indicator of previous message
hassMessage.text = hassMessage.text.substring(
0,
hassMessage.text.length - 1
);
hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(hassMessage);
} }
currentDeltaRole = delta.role;
}
if (
currentDeltaRole === "assistant" &&
"content" in delta &&
delta.content
) {
hassMessage.text =
hassMessage.text.substring(0, hassMessage.text.length - 1) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
}
if (event.type === "intent-end") { if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
hassMessage.text = plain.speech;
}
this.requestUpdate("_conversation");
unsub(); unsub();
} }
if (event.type === "error") { if (event.type === "error") {
hassMessage.text = event.data.message; hassMessageProcesser.setError(event.data.message);
hassMessage.error = true;
this.requestUpdate("_conversation");
unsub(); unsub();
} }
}, },
@ -560,20 +449,126 @@ export class HaAssistChat extends LitElement {
} }
); );
} catch { } catch {
hassMessage.text = this.hass.localize("ui.dialogs.voice_command.error"); hassMessageProcesser.setError(
hassMessage.error = true; this.hass.localize("ui.dialogs.voice_command.error")
this.requestUpdate("_conversation"); );
} finally { } finally {
this._processing = false; this._processing = false;
} }
} }
private _createAddHassMessageProcessor() {
let currentDeltaRole = "";
const progressToNextMessage = () => {
if (progress.hassMessage.text === "…") {
return;
}
progress.hassMessage.text = progress.hassMessage.text.substring(
0,
progress.hassMessage.text.length - 1
);
progress.hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(progress.hassMessage);
};
const isAssistantDelta = (
_delta: any
): _delta is Partial<ConversationChatLogAssistantDelta> =>
currentDeltaRole === "assistant";
const isToolResult = (
_delta: any
): _delta is ConversationChatLogToolResultDelta =>
currentDeltaRole === "tool_result";
const tools: Record<
string,
ConversationChatLogAssistantDelta["tool_calls"][0]
> = {};
const progress = {
continueConversation: false,
hassMessage: {
who: "hass",
text: "…",
error: false,
},
addMessage: () => {
this._addMessage(progress.hassMessage);
},
setError: (error: string) => {
progressToNextMessage();
progress.hassMessage.text = error;
progress.hassMessage.error = true;
this.requestUpdate("_conversation");
},
processEvent: (event: PipelineRunEvent) => {
if (event.type === "intent-progress") {
const delta = event.data.chat_log_delta;
// new message
if (delta.role) {
progressToNextMessage();
currentDeltaRole = delta.role;
}
if (isAssistantDelta(delta)) {
if (delta.content) {
progress.hassMessage.text =
progress.hassMessage.text.substring(
0,
progress.hassMessage.text.length - 1
) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
if (delta.tool_calls) {
for (const toolCall of delta.tool_calls) {
tools[toolCall.id] = toolCall;
}
}
} else if (isToolResult(delta)) {
if (tools[delta.tool_call_id]) {
delete tools[delta.tool_call_id];
}
}
} else if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
progress.continueConversation =
event.data.intent_output.continue_conversation;
const response =
event.data.intent_output.response.speech?.plain.speech;
if (!response) {
return;
}
if (event.data.intent_output.response.response_type === "error") {
progress.setError(response);
} else {
progress.hassMessage.text = response;
this.requestUpdate("_conversation");
}
}
},
};
return progress;
}
static styles = css` static styles = css`
:host { :host {
flex: 1; flex: 1;
display: flex; display: flex;
flex-direction: column; flex-direction: column;
} }
ha-alert {
margin-bottom: 8px;
}
ha-textfield { ha-textfield {
display: block; display: block;
} }
@ -581,17 +576,14 @@ export class HaAssistChat extends LitElement {
flex: 1; flex: 1;
display: block; display: block;
box-sizing: border-box; box-sizing: border-box;
position: relative;
}
.messages-container {
position: absolute;
bottom: 0px;
right: 0px;
left: 0px;
padding: 0px 10px 16px;
box-sizing: border-box;
overflow-y: auto; overflow-y: auto;
max-height: 100%; max-height: 100%;
display: flex;
flex-direction: column;
padding: 0 12px 16px;
}
.spacer {
flex: 1;
} }
.message { .message {
white-space: pre-line; white-space: pre-line;
@ -601,6 +593,9 @@ export class HaAssistChat extends LitElement {
padding: 8px; padding: 8px;
border-radius: 15px; border-radius: 15px;
} }
.message:last-child {
margin-bottom: 0;
}
@media all and (max-width: 450px), all and (max-height: 500px) { @media all and (max-width: 450px), all and (max-height: 500px) {
.message { .message {
@ -619,7 +614,7 @@ export class HaAssistChat extends LitElement {
margin-left: 24px; margin-left: 24px;
margin-inline-start: 24px; margin-inline-start: 24px;
margin-inline-end: initial; margin-inline-end: initial;
float: var(--float-end); align-self: flex-end;
text-align: right; text-align: right;
border-bottom-right-radius: 0px; border-bottom-right-radius: 0px;
background-color: var(--chat-background-color-user, var(--primary-color)); background-color: var(--chat-background-color-user, var(--primary-color));
@ -631,7 +626,7 @@ export class HaAssistChat extends LitElement {
margin-right: 24px; margin-right: 24px;
margin-inline-end: 24px; margin-inline-end: 24px;
margin-inline-start: initial; margin-inline-start: initial;
float: var(--float-start); align-self: flex-start;
border-bottom-left-radius: 0px; border-bottom-left-radius: 0px;
background-color: var( background-color: var(
--chat-background-color-hass, --chat-background-color-hass,

View File

@ -1,6 +1,5 @@
import type { HomeAssistant } from "../types"; import type { HomeAssistant } from "../types";
import type { ConversationResult } from "./conversation"; import type { ConversationResult } from "./conversation";
import type { ResolvedMediaSource } from "./media_source";
import type { SpeechMetadata } from "./stt"; import type { SpeechMetadata } from "./stt";
export interface AssistPipeline { export interface AssistPipeline {
@ -53,10 +52,16 @@ interface PipelineRunStartEvent extends PipelineEventBase {
data: { data: {
pipeline: string; pipeline: string;
language: string; language: string;
conversation_id: string;
runner_data: { runner_data: {
stt_binary_handler_id: number | null; stt_binary_handler_id: number | null;
timeout: number; timeout: number;
}; };
tts_output?: {
token: string;
url: string;
mime_type: string;
};
}; };
} }
interface PipelineRunEndEvent extends PipelineEventBase { interface PipelineRunEndEvent extends PipelineEventBase {
@ -109,7 +114,7 @@ interface PipelineIntentStartEvent extends PipelineEventBase {
}; };
} }
interface ConversationChatLogAssistantDelta { export interface ConversationChatLogAssistantDelta {
role: "assistant"; role: "assistant";
content: string; content: string;
tool_calls: { tool_calls: {
@ -119,7 +124,7 @@ interface ConversationChatLogAssistantDelta {
}[]; }[];
} }
interface ConversationChatLogToolResultDelta { export interface ConversationChatLogToolResultDelta {
role: "tool_result"; role: "tool_result";
agent_id: string; agent_id: string;
tool_call_id: string; tool_call_id: string;
@ -156,7 +161,12 @@ interface PipelineTTSStartEvent extends PipelineEventBase {
interface PipelineTTSEndEvent extends PipelineEventBase { interface PipelineTTSEndEvent extends PipelineEventBase {
type: "tts-end"; type: "tts-end";
data: { data: {
tts_output: ResolvedMediaSource; tts_output: {
media_id: string;
token: string;
url: string;
mime_type: string;
};
}; };
} }