Tweaks to Assist chat dialog (#25494)

Paulus Schoutsen 2025-05-20 02:39:56 -04:00 committed by GitHub
parent a026c72230
commit 83df10ef29
2 changed files with 195 additions and 190 deletions

@@ -5,8 +5,11 @@ import { customElement, property, query, state } from "lit/decorators";
import { classMap } from "lit/directives/class-map";
import type { HomeAssistant } from "../types";
import {
type PipelineRunEvent,
runAssistPipeline,
type AssistPipeline,
type ConversationChatLogAssistantDelta,
type ConversationChatLogToolResultDelta,
} from "../data/assist_pipeline";
import { supportsFeature } from "../common/entity/supports-feature";
import { ConversationEntityFeature } from "../data/conversation";
@@ -90,7 +93,7 @@ export class HaAssistChat extends LitElement {
super.disconnectedCallback();
this._audioRecorder?.close();
this._audioRecorder = undefined;
this._audio?.pause();
this._unloadAudio();
this._conversation = [];
this._conversationId = null;
}
@@ -109,25 +112,24 @@ export class HaAssistChat extends LitElement {
const supportsSTT = this.pipeline?.stt_engine && !this.disableSpeech;
return html`
${controlHA
? nothing
: html`
<ha-alert>
${this.hass.localize(
"ui.dialogs.voice_command.conversation_no_control"
)}
</ha-alert>
`}
<div class="messages">
<div class="messages-container" id="scroll-container">
${this._conversation!.map(
// New lines matter for messages
// prettier-ignore
(message) => html`
<div class="messages" id="scroll-container">
${controlHA
? nothing
: html`
<ha-alert>
${this.hass.localize(
"ui.dialogs.voice_command.conversation_no_control"
)}
</ha-alert>
`}
<div class="spacer"></div>
${this._conversation!.map(
// New lines matter for messages
// prettier-ignore
(message) => html`
<div class="message ${classMap({ error: !!message.error, [message.who]: true })}">${message.text}</div>
`
)}
</div>
)}
</div>
<div class="input" slot="primaryAction">
<ha-textfield
@@ -273,8 +275,8 @@ export class HaAssistChat extends LitElement {
}
private async _startListening() {
this._unloadAudio();
this._processing = true;
this._audio?.pause();
if (!this._audioRecorder) {
this._audioRecorder = new AudioRecorder((audio) => {
if (this._audioBuffer) {
@@ -293,27 +295,36 @@ export class HaAssistChat extends LitElement {
await this._audioRecorder.start();
this._addMessage(userMessage);
this.requestUpdate("_audioRecorder");
let continueConversation = false;
let hassMessage = {
who: "hass",
text: "…",
error: false,
};
let currentDeltaRole = "";
// To make sure the answer is placed at the right user text, we add it before we process it
const hassMessageProcesser = this._createAddHassMessageProcessor();
try {
const unsub = await runAssistPipeline(
this.hass,
(event) => {
(event: PipelineRunEvent) => {
if (event.type === "run-start") {
this._stt_binary_handler_id =
event.data.runner_data.stt_binary_handler_id;
this._audio = new Audio(event.data.tts_output!.url);
this._audio.play();
this._audio.addEventListener("ended", () => {
this._unloadAudio();
if (hassMessageProcesser.continueConversation) {
this._startListening();
}
});
this._audio.addEventListener("pause", this._unloadAudio);
this._audio.addEventListener("canplaythrough", () =>
this._audio?.play()
);
this._audio.addEventListener("error", () => {
this._unloadAudio();
showAlertDialog(this, { title: "Error playing audio." });
});
}
// When we start STT stage, the WS has a binary handler
if (event.type === "stt-start" && this._audioBuffer) {
else if (event.type === "stt-start" && this._audioBuffer) {
// Send the buffer over the WS to the STT engine.
for (const buffer of this._audioBuffer) {
this._sendAudioChunk(buffer);
@@ -322,91 +333,26 @@ export class HaAssistChat extends LitElement {
}
// Stop recording if the server is done with STT stage
if (event.type === "stt-end") {
else if (event.type === "stt-end") {
this._stt_binary_handler_id = undefined;
this._stopListening();
userMessage.text = event.data.stt_output.text;
this.requestUpdate("_conversation");
// To make sure the answer is placed at the right user text, we add it before we process it
this._addMessage(hassMessage);
}
if (event.type === "intent-progress") {
const delta = event.data.chat_log_delta;
// new message
if (delta.role) {
// If currentDeltaRole exists, it means we're receiving our
// second or later message. Let's add it to the chat.
if (currentDeltaRole && delta.role && hassMessage.text !== "…") {
// Remove progress indicator of previous message
hassMessage.text = hassMessage.text.substring(
0,
hassMessage.text.length - 1
);
hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(hassMessage);
}
currentDeltaRole = delta.role;
}
if (
currentDeltaRole === "assistant" &&
"content" in delta &&
delta.content
) {
hassMessage.text =
hassMessage.text.substring(0, hassMessage.text.length - 1) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
}
if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
continueConversation =
event.data.intent_output.continue_conversation;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
hassMessage.text = plain.speech;
}
this.requestUpdate("_conversation");
}
if (event.type === "tts-end") {
const url = event.data.tts_output.url;
this._audio = new Audio(url);
this._audio.play();
this._audio.addEventListener("ended", () => {
this._unloadAudio();
if (continueConversation) {
this._startListening();
}
});
this._audio.addEventListener("pause", this._unloadAudio);
this._audio.addEventListener("canplaythrough", this._playAudio);
this._audio.addEventListener("error", this._audioError);
}
if (event.type === "run-end") {
// Add the response message placeholder to the chat when we know the STT is done
hassMessageProcesser.addMessage();
} else if (event.type.startsWith("intent-")) {
hassMessageProcesser.processEvent(event);
} else if (event.type === "run-end") {
this._stt_binary_handler_id = undefined;
unsub();
}
if (event.type === "error") {
} else if (event.type === "error") {
this._unloadAudio();
this._stt_binary_handler_id = undefined;
if (userMessage.text === "…") {
userMessage.text = event.data.message;
userMessage.error = true;
} else {
hassMessage.text = event.data.message;
hassMessage.error = true;
hassMessageProcesser.setError(event.data.message);
}
this._stopListening();
this.requestUpdate("_conversation");
@@ -464,90 +410,33 @@ export class HaAssistChat extends LitElement {
this.hass.connection.socket!.send(data);
}
private _playAudio = () => {
this._audio?.play();
};
private _audioError = () => {
showAlertDialog(this, { title: "Error playing audio." });
this._audio?.removeAttribute("src");
};
private _unloadAudio = () => {
this._audio?.removeAttribute("src");
if (!this._audio) {
return;
}
this._audio.pause();
this._audio.removeAttribute("src");
this._audio = undefined;
};
private async _processText(text: string) {
this._unloadAudio();
this._processing = true;
this._audio?.pause();
this._addMessage({ who: "user", text });
let hassMessage = {
who: "hass",
text: "…",
error: false,
};
let currentDeltaRole = "";
// To make sure the answer is placed at the right user text, we add it before we process it
this._addMessage(hassMessage);
const hassMessageProcesser = this._createAddHassMessageProcessor();
hassMessageProcesser.addMessage();
try {
const unsub = await runAssistPipeline(
this.hass,
(event) => {
if (event.type === "intent-progress") {
const delta = event.data.chat_log_delta;
// new message and previous message has content
if (delta.role) {
// If currentDeltaRole exists, it means we're receiving our
// second or later message. Let's add it to the chat.
if (
currentDeltaRole &&
delta.role === "assistant" &&
hassMessage.text !== "…"
) {
// Remove progress indicator of previous message
hassMessage.text = hassMessage.text.substring(
0,
hassMessage.text.length - 1
);
hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(hassMessage);
}
currentDeltaRole = delta.role;
}
if (
currentDeltaRole === "assistant" &&
"content" in delta &&
delta.content
) {
hassMessage.text =
hassMessage.text.substring(0, hassMessage.text.length - 1) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
if (event.type.startsWith("intent-")) {
hassMessageProcesser.processEvent(event);
}
if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
hassMessage.text = plain.speech;
}
this.requestUpdate("_conversation");
unsub();
}
if (event.type === "error") {
hassMessage.text = event.data.message;
hassMessage.error = true;
this.requestUpdate("_conversation");
hassMessageProcesser.setError(event.data.message);
unsub();
}
},
@@ -560,20 +449,126 @@ export class HaAssistChat extends LitElement {
}
);
} catch {
hassMessage.text = this.hass.localize("ui.dialogs.voice_command.error");
hassMessage.error = true;
this.requestUpdate("_conversation");
hassMessageProcesser.setError(
this.hass.localize("ui.dialogs.voice_command.error")
);
} finally {
this._processing = false;
}
}
private _createAddHassMessageProcessor() {
let currentDeltaRole = "";
const progressToNextMessage = () => {
if (progress.hassMessage.text === "…") {
return;
}
progress.hassMessage.text = progress.hassMessage.text.substring(
0,
progress.hassMessage.text.length - 1
);
progress.hassMessage = {
who: "hass",
text: "…",
error: false,
};
this._addMessage(progress.hassMessage);
};
const isAssistantDelta = (
_delta: any
): _delta is Partial<ConversationChatLogAssistantDelta> =>
currentDeltaRole === "assistant";
const isToolResult = (
_delta: any
): _delta is ConversationChatLogToolResultDelta =>
currentDeltaRole === "tool_result";
const tools: Record<
string,
ConversationChatLogAssistantDelta["tool_calls"][0]
> = {};
const progress = {
continueConversation: false,
hassMessage: {
who: "hass",
text: "…",
error: false,
},
addMessage: () => {
this._addMessage(progress.hassMessage);
},
setError: (error: string) => {
progressToNextMessage();
progress.hassMessage.text = error;
progress.hassMessage.error = true;
this.requestUpdate("_conversation");
},
processEvent: (event: PipelineRunEvent) => {
if (event.type === "intent-progress") {
const delta = event.data.chat_log_delta;
// new message
if (delta.role) {
progressToNextMessage();
currentDeltaRole = delta.role;
}
if (isAssistantDelta(delta)) {
if (delta.content) {
progress.hassMessage.text =
progress.hassMessage.text.substring(
0,
progress.hassMessage.text.length - 1
) +
delta.content +
"…";
this.requestUpdate("_conversation");
}
if (delta.tool_calls) {
for (const toolCall of delta.tool_calls) {
tools[toolCall.id] = toolCall;
}
}
} else if (isToolResult(delta)) {
if (tools[delta.tool_call_id]) {
delete tools[delta.tool_call_id];
}
}
} else if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
progress.continueConversation =
event.data.intent_output.continue_conversation;
const response =
event.data.intent_output.response.speech?.plain.speech;
if (!response) {
return;
}
if (event.data.intent_output.response.response_type === "error") {
progress.setError(response);
} else {
progress.hassMessage.text = response;
this.requestUpdate("_conversation");
}
}
},
};
return progress;
}
static styles = css`
:host {
flex: 1;
display: flex;
flex-direction: column;
}
ha-alert {
margin-bottom: 8px;
}
ha-textfield {
display: block;
}
@@ -581,17 +576,14 @@ export class HaAssistChat extends LitElement {
flex: 1;
display: block;
box-sizing: border-box;
position: relative;
}
.messages-container {
position: absolute;
bottom: 0px;
right: 0px;
left: 0px;
padding: 0px 10px 16px;
box-sizing: border-box;
overflow-y: auto;
max-height: 100%;
display: flex;
flex-direction: column;
padding: 0 12px 16px;
}
.spacer {
flex: 1;
}
.message {
white-space: pre-line;
@@ -601,6 +593,9 @@ export class HaAssistChat extends LitElement {
padding: 8px;
border-radius: 15px;
}
.message:last-child {
margin-bottom: 0;
}
@media all and (max-width: 450px), all and (max-height: 500px) {
.message {
@@ -619,7 +614,7 @@ export class HaAssistChat extends LitElement {
margin-left: 24px;
margin-inline-start: 24px;
margin-inline-end: initial;
float: var(--float-end);
align-self: flex-end;
text-align: right;
border-bottom-right-radius: 0px;
background-color: var(--chat-background-color-user, var(--primary-color));
@@ -631,7 +626,7 @@ export class HaAssistChat extends LitElement {
margin-right: 24px;
margin-inline-end: 24px;
margin-inline-start: initial;
float: var(--float-start);
align-self: flex-start;
border-bottom-left-radius: 0px;
background-color: var(
--chat-background-color-hass,

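For context on the refactor above: the streamed-reply handling that used to be duplicated in `_startListening` and `_processText` now lives in `_createAddHassMessageProcessor`. The snippet below is an editorial sketch, not part of this commit, of the same delta-accumulation pattern, built only on the event shapes from the assist pipeline data module (whose diff follows); the `onText` callback is a hypothetical placeholder for whatever updates the chat UI.

```ts
import type { PipelineRunEvent } from "../data/assist_pipeline";

// Sketch: accumulate streamed assistant text from "intent-progress" deltas
// and finalize it on "intent-end", mirroring the processor added above.
// `onText` is a hypothetical UI callback, not part of the commit.
export const createPipelineEventHandler = (
  onText: (text: string, error?: boolean) => void
) => {
  let reply = "";
  return (event: PipelineRunEvent) => {
    if (event.type === "intent-progress") {
      const delta = event.data.chat_log_delta;
      if ("content" in delta && delta.content) {
        reply += delta.content;
        // Keep the trailing ellipsis while the reply is still streaming.
        onText(`${reply}…`);
      }
    } else if (event.type === "intent-end") {
      // Prefer the final spoken response over the accumulated stream.
      const plain = event.data.intent_output.response.speech?.plain;
      onText(plain ? plain.speech : reply);
    } else if (event.type === "error") {
      onText(event.data.message, true);
    }
  };
};
```

Passing the returned handler as the second argument to `runAssistPipeline` would reproduce the text-only flow of `_processText` in simplified form.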

@@ -1,6 +1,5 @@
import type { HomeAssistant } from "../types";
import type { ConversationResult } from "./conversation";
import type { ResolvedMediaSource } from "./media_source";
import type { SpeechMetadata } from "./stt";
export interface AssistPipeline {
@@ -53,10 +52,16 @@ interface PipelineRunStartEvent extends PipelineEventBase {
data: {
pipeline: string;
language: string;
conversation_id: string;
runner_data: {
stt_binary_handler_id: number | null;
timeout: number;
};
tts_output?: {
token: string;
url: string;
mime_type: string;
};
};
}
interface PipelineRunEndEvent extends PipelineEventBase {
@@ -109,7 +114,7 @@ interface PipelineIntentStartEvent extends PipelineEventBase {
};
}
interface ConversationChatLogAssistantDelta {
export interface ConversationChatLogAssistantDelta {
role: "assistant";
content: string;
tool_calls: {
@@ -119,7 +124,7 @@ interface ConversationChatLogAssistantDelta {
}[];
}
interface ConversationChatLogToolResultDelta {
export interface ConversationChatLogToolResultDelta {
role: "tool_result";
agent_id: string;
tool_call_id: string;
@@ -156,7 +161,12 @@ interface PipelineTTSStartEvent extends PipelineEventBase {
interface PipelineTTSEndEvent extends PipelineEventBase {
type: "tts-end";
data: {
tts_output: ResolvedMediaSource;
tts_output: {
media_id: string;
token: string;
url: string;
mime_type: string;
};
};
}
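
The `tts_output` block that `run-start` events now carry (token, url, mime_type) is what lets the dialog create its `Audio` element up front and start playback on `canplaythrough`. Below is an editorial sketch of that wiring, assuming only the event shapes above; `onFinished` is a hypothetical callback, for example to resume listening when the conversation should continue.

```ts
import type { PipelineRunEvent } from "../data/assist_pipeline";

// Sketch: attach the streamed TTS URL announced at "run-start" to an
// HTMLAudioElement, following the listener setup in the dialog above.
// `onFinished` is a hypothetical callback, e.g. to restart listening.
export const attachTtsPlayback = (
  event: PipelineRunEvent,
  onFinished: () => void
): HTMLAudioElement | undefined => {
  if (event.type !== "run-start") {
    return undefined;
  }
  // tts_output is optional on run-start events.
  const ttsOutput = event.data.tts_output;
  if (!ttsOutput) {
    return undefined;
  }
  const audio = new Audio(ttsOutput.url);
  // Begin playback only once the browser reports enough buffered data.
  audio.addEventListener("canplaythrough", () => audio.play());
  audio.addEventListener("ended", onFinished);
  audio.addEventListener("error", () => {
    // Mirror the dialog's error handling: drop the source on failure.
    audio.removeAttribute("src");
    onFinished();
  });
  return audio;
};
```

In the dialog itself the `ended` listener additionally checks `continueConversation` before calling `_startListening`; the sketch leaves that decision to the caller.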