diff --git a/src/data/voice_assistant.ts b/src/data/voice_assistant.ts
index c332e19267..9b5ae8f65d 100644
--- a/src/data/voice_assistant.ts
+++ b/src/data/voice_assistant.ts
@@ -84,7 +84,7 @@ type PipelineRunEvent =
   | PipelineTTSStartEvent
   | PipelineTTSEndEvent;
 
-interface PipelineRunOptions {
+export interface PipelineRunOptions {
   start_stage: "stt" | "intent" | "tts";
   end_stage: "stt" | "intent" | "tts";
   language?: string;
@@ -99,13 +99,15 @@ export interface PipelineRun {
   stage: "ready" | "stt" | "intent" | "tts" | "done" | "error";
   run: PipelineRunStartEvent["data"];
   error?: PipelineErrorEvent["data"];
-  stt?: PipelineSTTStartEvent["data"] & Partial<PipelineSTTEndEvent["data"]>;
+  stt?: PipelineSTTStartEvent["data"] &
+    Partial<PipelineSTTEndEvent["data"]> & { done: boolean };
   intent?: PipelineIntentStartEvent["data"] &
-    Partial<PipelineIntentEndEvent["data"]>;
-  tts?: PipelineTTSStartEvent["data"] & Partial<PipelineTTSEndEvent["data"]>;
+    Partial<PipelineIntentEndEvent["data"]> & { done: boolean };
+  tts?: PipelineTTSStartEvent["data"] &
+    Partial<PipelineTTSEndEvent["data"]> & { done: boolean };
 }
 
-export const runPipelineFromText = (
+export const runVoiceAssistantPipeline = (
   hass: HomeAssistant,
   callback: (event: PipelineRun) => void,
   options: PipelineRunOptions
@@ -139,17 +141,38 @@ export const runPipelineFromText = (
       }
 
       if (updateEvent.type === "stt-start") {
-        run = { ...run, stage: "stt", stt: updateEvent.data };
+        run = {
+          ...run,
+          stage: "stt",
+          stt: { ...updateEvent.data, done: false },
+        };
       } else if (updateEvent.type === "stt-end") {
-        run = { ...run, stt: { ...run.stt!, ...updateEvent.data } };
+        run = {
+          ...run,
+          stt: { ...run.stt!, ...updateEvent.data, done: true },
+        };
       } else if (updateEvent.type === "intent-start") {
-        run = { ...run, stage: "intent", intent: updateEvent.data };
+        run = {
+          ...run,
+          stage: "intent",
+          intent: { ...updateEvent.data, done: false },
+        };
       } else if (updateEvent.type === "intent-end") {
-        run = { ...run, intent: { ...run.intent!, ...updateEvent.data } };
+        run = {
+          ...run,
+          intent: { ...run.intent!, ...updateEvent.data, done: true },
+        };
       } else if (updateEvent.type === "tts-start") {
-        run = { ...run, stage: "tts", tts: updateEvent.data };
+        run = {
+          ...run,
+          stage: "tts",
+          tts: { ...updateEvent.data, done: false },
+        };
       } else if (updateEvent.type === "tts-end") {
-        run = { ...run, tts: { ...run.tts!, ...updateEvent.data } };
+        run = {
+          ...run,
+          tts: { ...run.tts!, ...updateEvent.data, done: true },
+        };
       } else if (updateEvent.type === "run-end") {
         run = { ...run, stage: "done" };
         unsubProm.then((unsub) => unsub());
diff --git a/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-pipeline-debug.ts b/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-pipeline-debug.ts
index c0dcd3d5af..49906ad2a9 100644
--- a/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-pipeline-debug.ts
+++ b/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-pipeline-debug.ts
@@ -1,20 +1,25 @@
-import { css, html, LitElement, PropertyValues, TemplateResult } from "lit";
+import { css, html, LitElement, TemplateResult } from "lit";
 import { customElement, property, query, state } from "lit/decorators";
 import "../../../../../../components/ha-button";
 import {
   PipelineRun,
-  runPipelineFromText,
+  PipelineRunOptions,
+  runVoiceAssistantPipeline,
 } from "../../../../../../data/voice_assistant";
 import "../../../../../../layouts/hass-subpage";
 import "../../../../../../components/ha-formfield";
 import "../../../../../../components/ha-checkbox";
 import { haStyle } from "../../../../../../resources/styles";
 import type { HomeAssistant } from "../../../../../../types";
-import { showPromptDialog } from "../../../../../../dialogs/generic/show-dialog-box";
+import {
+  showAlertDialog,
+  showPromptDialog,
+} from "../../../../../../dialogs/generic/show-dialog-box";
 import "./assist-render-pipeline-run";
 import type { HaCheckbox } from "../../../../../../components/ha-checkbox";
 import type { HaTextField } from "../../../../../../components/ha-textfield";
 import "../../../../../../components/ha-textfield";
+import { fileDownload } from "../../../../../../util/file_download";
 
 @customElement("assist-pipeline-debug")
 export class AssistPipelineDebug extends LitElement {
@@ -24,8 +29,6 @@ export class AssistPipelineDebug extends LitElement {
 
   @state() private _pipelineRuns: PipelineRun[] = [];
 
-  @state() private _stopRecording?: () => void;
-
   @query("#continue-conversation")
   private _continueConversationCheckbox!: HaCheckbox;
 
@@ -36,6 +39,8 @@ export class AssistPipelineDebug extends LitElement {
 
   @state() private _finished = false;
 
+  @state() private _languageOverride?: string;
+
   protected render(): TemplateResult {
     return html`
                 Clear
+              <ha-button @click=${this._downloadConversation}>
+                Download
+              </ha-button>
             `
-          : ""}
+          : html`
+              <ha-button @click=${this._setLanguage}>
+                Set Language
+              </ha-button>
+            `}
@@ -81,6 +96,12 @@ export class AssistPipelineDebug extends LitElement {
                 Send
               </ha-button>
             `
+          : this._finished
+          ? html`
+              <ha-button @click=${this._runAudioPipeline}>
+                Continue talking
+              </ha-button>
+            `
           : html`
-      audio.addEventListener("ended", () => {
-        if (this._continueConversationCheckbox.checked) {
-          this._runAudioPipeline();
-        } else {
-          this._finished = true;
-        }
-      });
-      audio.play();
-    } else if (currentRun.stage === "error") {
-      this._finished = true;
-    }
-  }
-
   private get conversationId(): string | null {
     return this._pipelineRuns.length === 0
       ? null
@@ -177,26 +151,19 @@
       return;
     }
 
-    let added = false;
-    runPipelineFromText(
-      this.hass,
+    await this._doRunPipeline(
       (run) => {
-        if (textfield && ["done", "error"].includes(run.stage)) {
-          textfield.value = "";
-        }
-
-        if (added) {
-          this._pipelineRuns = [run, ...this._pipelineRuns.slice(1)];
-        } else {
-          this._pipelineRuns = [run, ...this._pipelineRuns];
-          added = true;
+        if (["done", "error"].includes(run.stage)) {
+          this._finished = true;
+          if (textfield) {
+            textfield.value = "";
+          }
         }
       },
       {
         start_stage: "intent",
         end_stage: "intent",
         input: { text },
-        conversation_id: this.conversationId,
       }
     );
   }
@@ -204,7 +171,13 @@
   private async _runAudioPipeline() {
     // @ts-ignore-next-line
    const context = new (window.AudioContext || window.webkitAudioContext)();
-    const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+    let stream: MediaStream;
+    try {
+      stream = await navigator.mediaDevices.getUserMedia({ audio: true });
+    } catch (err) {
+      return;
+    }
+
     await context.audioWorklet.addModule(
       new URL("./recorder.worklet.js", import.meta.url)
     );
@@ -213,47 +186,111 @@
     const recorder = new AudioWorkletNode(context, "recorder.worklet");
     this.hass.connection.socket!.binaryType = "arraybuffer";
-    this._stopRecording = () => {
+
+    let run: PipelineRun | undefined;
+
+    let stopRecording: (() => void) | undefined = () => {
+      stopRecording = undefined;
+      // We're currently STTing, so finish audio
+      if (run?.stage === "stt" && run.stt!.done === false) {
+        if (this._audioBuffer) {
+          for (const chunk of this._audioBuffer) {
+            this._sendAudioChunk(chunk);
+          }
+        }
+        // Send empty message to indicate we're done streaming.
+        this._sendAudioChunk(new Int16Array());
+      }
+      this._audioBuffer = undefined;
       stream.getTracks()[0].stop();
       context.close();
-      this._stopRecording = undefined;
-      this._audioBuffer = undefined;
-      // Send empty message to indicate we're done streaming.
-      this._sendAudioChunk(new Int16Array());
     };
     this._audioBuffer = [];
     source.connect(recorder).connect(context.destination);
     recorder.port.onmessage = (e) => {
+      if (!stopRecording) {
+        return;
+      }
       if (this._audioBuffer) {
         this._audioBuffer.push(e.data);
-        return;
+      } else {
+        this._sendAudioChunk(e.data);
       }
-      if (this._pipelineRuns[0].stage !== "stt") {
-        return;
-      }
-      this._sendAudioChunk(e.data);
     };
-    this._finished = false;
-    let added = false;
-    runPipelineFromText(
-      this.hass,
-      (run) => {
-        if (added) {
-          this._pipelineRuns = [run, ...this._pipelineRuns.slice(1)];
-        } else {
-          this._pipelineRuns = [run, ...this._pipelineRuns];
-          added = true;
+    await this._doRunPipeline(
+      (updatedRun) => {
+        run = updatedRun;
+
+        // When we start STT stage, the WS has a binary handler
+        if (updatedRun.stage === "stt" && this._audioBuffer) {
+          // Send the buffer over the WS to the STT engine.
+          for (const buffer of this._audioBuffer) {
+            this._sendAudioChunk(buffer);
+          }
+          this._audioBuffer = undefined;
+        }
+
+        // Stop recording if the server is done with STT stage
+        if (!["ready", "stt"].includes(updatedRun.stage) && stopRecording) {
+          stopRecording();
+        }
+
+        // Play audio when we're done.
+        if (updatedRun.stage === "done") {
+          const url = updatedRun.tts!.tts_output!.url;
+          const audio = new Audio(url);
+          audio.addEventListener("ended", () => {
+            if (this._continueConversationCheckbox.checked) {
+              this._runAudioPipeline();
+            } else {
+              this._finished = true;
+            }
+          });
+          audio.play();
+        } else if (updatedRun.stage === "error") {
+          this._finished = true;
        }
       },
       {
         start_stage: "stt",
         end_stage: "tts",
-        conversation_id: this.conversationId,
       }
     );
   }
 
+  private async _doRunPipeline(
+    callback: (event: PipelineRun) => void,
+    options: PipelineRunOptions
+  ) {
+    this._finished = false;
+    let added = false;
+    try {
+      await runVoiceAssistantPipeline(
+        this.hass,
+        (updatedRun) => {
+          if (added) {
+            this._pipelineRuns = [updatedRun, ...this._pipelineRuns.slice(1)];
+          } else {
+            this._pipelineRuns = [updatedRun, ...this._pipelineRuns];
+            added = true;
+          }
+          callback(updatedRun);
+        },
+        {
+          ...options,
+          language: this._languageOverride,
+          conversation_id: this.conversationId,
+        }
+      );
+    } catch (err: any) {
+      await showAlertDialog(this, {
+        title: "Error starting pipeline",
+        text: err.message || err,
+      });
+    }
+  }
+
   private _sendAudioChunk(chunk: Int16Array) {
     // Turn into 8 bit so we can prefix our handler ID.
     const data = new Uint8Array(1 + chunk.length * 2);
@@ -273,6 +310,27 @@
     this._pipelineRuns = [];
   }
 
+  private _downloadConversation() {
+    fileDownload(
+      `data:text/plain;charset=utf-8,${encodeURIComponent(
+        JSON.stringify(this._pipelineRuns, null, 2)
+      )}`,
+      `conversation.json`
+    );
+  }
+
+  private async _setLanguage() {
+    const language = await showPromptDialog(this, {
+      title: "Language override",
+      inputLabel: "Language",
+      inputType: "text",
+      confirmText: "Set",
+    });
+    if (language) {
+      this._languageOverride = language;
+    }
+  }
+
   static styles = [
     haStyle,
     css`
diff --git a/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-render-pipeline-run.ts b/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-render-pipeline-run.ts
index f9940e0cb2..596b7f7560 100644
--- a/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-render-pipeline-run.ts
+++ b/src/panels/config/integrations/integration-panels/voice_assistant/assist/assist-render-pipeline-run.ts
@@ -50,9 +50,11 @@ const maybeRenderError = (
     return "";
   }
 
-  return html`<ha-alert alert-type="error">
-    ${run.error!.message} (${run.error!.code})
-  </ha-alert>`;
+  return html`
+    <ha-alert alert-type="error">
+      ${run.error!.message} (${run.error!.code})
+    </ha-alert>
+  `;
 };
 
 const renderProgress = (
@@ -76,10 +78,9 @@
   }
 
   if (!finishEvent) {
-    return html``;
+    return html`
+
+    `;
   }
 
   const duration =
@@ -109,7 +110,7 @@ const dataMinusKeysRender = (
   const result = {};
   let render = false;
   for (const key in data) {
-    if (key in keys) {
+    if (key in keys || key === "done") {
       continue;
     }
     render = true;
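Usage note: the sketch below shows how a caller outside the debug panel might drive the renamed helper after this change. It only relies on names visible in the hunks above (runVoiceAssistantPipeline, PipelineRun, PipelineRunOptions, the per-stage done flag, and the run/error stages); the import paths, the hass handle, the promise wrapper, and the input field of PipelineRunOptions (mirroring the debug panel's own input: { text } call) are assumptions for illustration, not part of the diff.

  import type { HomeAssistant } from "../types"; // path is an assumption
  import {
    PipelineRun,
    PipelineRunOptions,
    runVoiceAssistantPipeline,
  } from "../data/voice_assistant"; // path is an assumption

  // Run a text-only (intent -> intent) pipeline and resolve once the run
  // reports "done" with a finished intent stage, or reject on "error".
  export const runTextIntent = (
    hass: HomeAssistant,
    text: string
  ): Promise<PipelineRun> =>
    new Promise((resolve, reject) => {
      const options: PipelineRunOptions = {
        start_stage: "intent",
        end_stage: "intent",
        input: { text },
      };
      runVoiceAssistantPipeline(
        hass,
        (run) => {
          if (run.stage === "done" && run.intent?.done) {
            resolve(run);
          } else if (run.stage === "error") {
            reject(run.error);
          }
        },
        options
      );
    });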