From 1ded47d36829ebd55e4f76e353576278c5fc9283 Mon Sep 17 00:00:00 2001 From: Bram Kragten Date: Sat, 22 Apr 2023 03:50:30 +0200 Subject: [PATCH] Use assist_pipeline in voice command dialog (#16257) * Remove speech recognition, add basic pipeline support * Add basic voice support * cleanup * only use tts if pipeline supports it * Update ha-voice-command-dialog.ts * Fix types * handle stop during stt * Revert "Fix types" This reverts commit 741781e392048d2e29594e388386bebce70ff68e. * active read only * Update ha-voice-command-dialog.ts --------- Co-authored-by: Paul Bottein --- src/common/dom/speech-recognition.ts | 7 - src/components/ha-assist-pipeline-picker.ts | 4 +- src/data/assist_pipeline.ts | 30 +- .../ha-voice-command-dialog.ts | 358 ++++++++++-------- .../config/voice-assistants/assist-pref.ts | 4 +- .../debug/assist-pipeline-run-debug.ts | 55 ++- src/util/audio-recorder.ts | 88 +++++ .../debug => util}/recorder.worklet.js | 0 8 files changed, 340 insertions(+), 206 deletions(-) delete mode 100644 src/common/dom/speech-recognition.ts create mode 100644 src/util/audio-recorder.ts rename src/{panels/config/voice-assistants/debug => util}/recorder.worklet.js (100%) diff --git a/src/common/dom/speech-recognition.ts b/src/common/dom/speech-recognition.ts deleted file mode 100644 index 6efec5b8a2..0000000000 --- a/src/common/dom/speech-recognition.ts +++ /dev/null @@ -1,7 +0,0 @@ -export const SpeechRecognition = - window.SpeechRecognition || window.webkitSpeechRecognition; -export const SpeechGrammarList = - window.SpeechGrammarList || window.webkitSpeechGrammarList; -export const SpeechRecognitionEvent = - // @ts-expect-error - window.SpeechRecognitionEvent || window.webkitSpeechRecognitionEvent; diff --git a/src/components/ha-assist-pipeline-picker.ts b/src/components/ha-assist-pipeline-picker.ts index 21a1f9bb00..d5a1ce0b6c 100644 --- a/src/components/ha-assist-pipeline-picker.ts +++ b/src/components/ha-assist-pipeline-picker.ts @@ -9,7 +9,7 @@ import { import { customElement, property, state } from "lit/decorators"; import { fireEvent } from "../common/dom/fire_event"; import { stopPropagation } from "../common/dom/stop_propagation"; -import { AssistPipeline, fetchAssistPipelines } from "../data/assist_pipeline"; +import { AssistPipeline, listAssistPipelines } from "../data/assist_pipeline"; import { HomeAssistant } from "../types"; import "./ha-list-item"; import "./ha-select"; @@ -71,7 +71,7 @@ export class HaAssistPipelinePicker extends LitElement { changedProperties: PropertyValueMap | Map ): void { super.firstUpdated(changedProperties); - fetchAssistPipelines(this.hass).then((pipelines) => { + listAssistPipelines(this.hass).then((pipelines) => { this._pipelines = pipelines.pipelines; this._preferredPipeline = pipelines.preferred_pipeline; }); diff --git a/src/data/assist_pipeline.ts b/src/data/assist_pipeline.ts index c9c214760e..4d699359e3 100644 --- a/src/data/assist_pipeline.ts +++ b/src/data/assist_pipeline.ts @@ -210,14 +210,15 @@ export const processEvent = ( return run; }; -export const runAssistPipeline = ( +export const runDebugAssistPipeline = ( hass: HomeAssistant, - callback: (event: PipelineRun) => void, + callback: (run: PipelineRun) => void, options: PipelineRunOptions ) => { let run: PipelineRun | undefined; - const unsubProm = hass.connection.subscribeMessage( + const unsubProm = runAssistPipeline( + hass, (updateEvent) => { run = processEvent(run, updateEvent, options); @@ -229,15 +230,22 @@ export const runAssistPipeline = ( callback(run); } }, - { - 
...options, - type: "assist_pipeline/run", - } + options ); return unsubProm; }; +export const runAssistPipeline = ( + hass: HomeAssistant, + callback: (event: PipelineRunEvent) => void, + options: PipelineRunOptions +) => + hass.connection.subscribeMessage(callback, { + ...options, + type: "assist_pipeline/run", + }); + export const listAssistPipelineRuns = ( hass: HomeAssistant, pipeline_id: string @@ -262,7 +270,7 @@ export const getAssistPipelineRun = ( pipeline_run_id, }); -export const fetchAssistPipelines = (hass: HomeAssistant) => +export const listAssistPipelines = (hass: HomeAssistant) => hass.callWS<{ pipelines: AssistPipeline[]; preferred_pipeline: string | null; @@ -270,6 +278,12 @@ export const fetchAssistPipelines = (hass: HomeAssistant) => type: "assist_pipeline/pipeline/list", }); +export const getAssistPipeline = (hass: HomeAssistant, pipeline_id?: string) => + hass.callWS({ + type: "assist_pipeline/pipeline/get", + pipeline_id, + }); + export const createAssistPipeline = ( hass: HomeAssistant, pipeline: AssistPipelineMutableParams diff --git a/src/dialogs/voice-command-dialog/ha-voice-command-dialog.ts b/src/dialogs/voice-command-dialog/ha-voice-command-dialog.ts index b32d0d40e5..ebcbf185d7 100644 --- a/src/dialogs/voice-command-dialog/ha-voice-command-dialog.ts +++ b/src/dialogs/voice-command-dialog/ha-voice-command-dialog.ts @@ -1,4 +1,3 @@ -/* eslint-disable lit/prefer-static-styles */ import "@material/mwc-button/mwc-button"; import { mdiClose, @@ -11,28 +10,28 @@ import { CSSResultGroup, html, LitElement, - PropertyValues, nothing, + PropertyValues, } from "lit"; import { customElement, property, query, state } from "lit/decorators"; -import { classMap } from "lit/directives/class-map"; +import { LocalStorage } from "../../common/decorators/local-storage"; import { fireEvent } from "../../common/dom/fire_event"; -import { SpeechRecognition } from "../../common/dom/speech-recognition"; import "../../components/ha-dialog"; -import type { HaDialog } from "../../components/ha-dialog"; import "../../components/ha-header-bar"; import "../../components/ha-icon-button"; import "../../components/ha-textfield"; import type { HaTextField } from "../../components/ha-textfield"; import { - AgentInfo, - getAgentInfo, - prepareConversation, - processConversationInput, -} from "../../data/conversation"; + AssistPipeline, + getAssistPipeline, + runAssistPipeline, +} from "../../data/assist_pipeline"; +import { AgentInfo, getAgentInfo } from "../../data/conversation"; import { haStyleDialog } from "../../resources/styles"; import type { HomeAssistant } from "../../types"; +import { AudioRecorder } from "../../util/audio-recorder"; import { documentationUrl } from "../../util/documentation-url"; +import { showAlertDialog } from "../generic/show-dialog-box"; interface Message { who: string; @@ -40,49 +39,53 @@ interface Message { error?: boolean; } -interface Results { - transcript: string; - final: boolean; -} - @customElement("ha-voice-command-dialog") export class HaVoiceCommandDialog extends LitElement { @property({ attribute: false }) public hass!: HomeAssistant; - @property() public results: Results | null = null; - - @state() private _conversation: Message[] = [ - { - who: "hass", - text: "", - }, - ]; + @state() private _conversation?: Message[]; @state() private _opened = false; + @LocalStorage("AssistPipelineId", true, false) private _pipelineId?: string; + + @state() private _pipeline?: AssistPipeline; + @state() private _agentInfo?: AgentInfo; @state() private 
_showSendButton = false; - @query("#scroll-container") private _scrollContainer!: HaDialog; + @query("#scroll-container") private _scrollContainer!: HTMLDivElement; @query("#message-input") private _messageInput!: HaTextField; - private recognition!: SpeechRecognition; - private _conversationId: string | null = null; + private _audioRecorder?: AudioRecorder; + + private _audioBuffer?: Int16Array[]; + + private _stt_binary_handler_id?: number | null; + public async showDialog(): Promise { + this._conversation = [ + { + who: "hass", + text: this.hass.localize("ui.dialogs.voice_command.how_can_i_help"), + }, + ]; this._opened = true; - this._agentInfo = await getAgentInfo(this.hass); + await this.updateComplete; this._scrollMessagesBottom(); } public async closeDialog(): Promise { this._opened = false; - if (this.recognition) { - this.recognition.abort(); - } + this._agentInfo = undefined; + this._conversation = undefined; + this._conversationId = null; + this._audioRecorder?.close(); + this._audioRecorder = undefined; fireEvent(this, "dialog-closed", { dialog: this.localName }); } @@ -90,6 +93,7 @@ export class HaVoiceCommandDialog extends LitElement { if (!this._opened) { return nothing; } + const supportsSTT = this._pipeline?.stt_engine && AudioRecorder.isSupported; return html`
-          ${this._conversation.map(
+          ${this._conversation!.map(
            (message) => html`
              <div class=${this._computeMessageClasses(message)}>
                ${message.text}
              </div>
            `
          )}
-          ${this.results
-            ? html`
-                <div class="message user">
-                  <span
-                    class=${classMap({
-                      interimTranscript: !this.results.final,
-                    })}
-                    >${this.results.transcript}</span
-                  >${!this.results.final ? "…" : ""}
-                </div>
-              `
-            : ""}
@@ -166,9 +158,9 @@ export class HaVoiceCommandDialog extends LitElement {
                  >
                </ha-icon-button>
              `
-            : SpeechRecognition
+            : supportsSTT
            ? html`
-                ${this.results
+                ${this._audioRecorder?.active
                  ? html`
@@ -205,15 +197,18 @@ export class HaVoiceCommandDialog extends LitElement { `; } - protected firstUpdated(changedProps: PropertyValues) { - super.updated(changedProps); - this._conversation = [ - { - who: "hass", - text: this.hass.localize("ui.dialogs.voice_command.how_can_i_help"), - }, - ]; - prepareConversation(this.hass, this.hass.language); + protected willUpdate(changedProperties: PropertyValues): void { + if (!this.hasUpdated || changedProperties.has("_pipelineId")) { + this._getPipeline(); + } + } + + private async _getPipeline() { + this._pipeline = await getAssistPipeline(this.hass, this._pipelineId); + this._agentInfo = await getAgentInfo( + this.hass, + this._pipeline.conversation_engine + ); } protected updated(changedProps: PropertyValues) { @@ -224,7 +219,7 @@ export class HaVoiceCommandDialog extends LitElement { } private _addMessage(message: Message) { - this._conversation = [...this._conversation, message]; + this._conversation = [...this._conversation!, message]; } private _handleKeyUp(ev: KeyboardEvent) { @@ -253,75 +248,7 @@ export class HaVoiceCommandDialog extends LitElement { } } - private _initRecognition() { - this.recognition = new SpeechRecognition(); - this.recognition.interimResults = true; - this.recognition.lang = this.hass.language; - this.recognition.continuous = false; - - this.recognition.addEventListener("start", () => { - this.results = { - final: false, - transcript: "", - }; - }); - this.recognition.addEventListener("nomatch", () => { - this._addMessage({ - who: "user", - text: `<${this.hass.localize( - "ui.dialogs.voice_command.did_not_understand" - )}>`, - error: true, - }); - }); - this.recognition.addEventListener("error", (event) => { - // eslint-disable-next-line - console.error("Error recognizing text", event); - this.recognition!.abort(); - // @ts-ignore - if (event.error !== "aborted" && event.error !== "no-speech") { - const text = - this.results && this.results.transcript - ? 
this.results.transcript - : `<${this.hass.localize( - "ui.dialogs.voice_command.did_not_hear" - )}>`; - this._addMessage({ who: "user", text, error: true }); - } - this.results = null; - }); - this.recognition.addEventListener("end", () => { - // Already handled by onerror - if (this.results == null) { - return; - } - const text = this.results.transcript; - this.results = null; - if (text) { - this._processText(text); - } else { - this._addMessage({ - who: "user", - text: `<${this.hass.localize( - "ui.dialogs.voice_command.did_not_hear" - )}>`, - error: true, - }); - } - }); - this.recognition.addEventListener("result", (event) => { - const result = event.results[0]; - this.results = { - transcript: result[0].transcript, - final: result.isFinal, - }; - }); - } - private async _processText(text: string) { - if (this.recognition) { - this.recognition.abort(); - } this._addMessage({ who: "user", text }); const message: Message = { who: "hass", @@ -330,21 +257,33 @@ export class HaVoiceCommandDialog extends LitElement { // To make sure the answer is placed at the right user text, we add it before we process it this._addMessage(message); try { - const response = await processConversationInput( + const unsub = await runAssistPipeline( this.hass, - text, - this._conversationId, - this.hass.language + (event) => { + if (event.type === "intent-end") { + this._conversationId = event.data.intent_output.conversation_id; + const plain = event.data.intent_output.response.speech?.plain; + if (plain) { + message.text = plain.speech; + } + this.requestUpdate("_conversation"); + unsub(); + } + if (event.type === "error") { + message.text = event.data.message; + message.error = true; + this.requestUpdate("_conversation"); + unsub(); + } + }, + { + start_stage: "intent", + input: { text }, + end_stage: "intent", + pipeline: this._pipelineId, + conversation_id: this._conversationId, + } ); - this._conversationId = response.conversation_id; - const plain = response.response.speech?.plain; - if (plain) { - message.text = plain.speech; - } else { - message.text = ""; - } - - this.requestUpdate("_conversation"); } catch { message.text = this.hass.localize("ui.dialogs.voice_command.error"); message.error = true; @@ -353,37 +292,152 @@ export class HaVoiceCommandDialog extends LitElement { } private _toggleListening() { - if (!this.results) { + if (!this._audioRecorder?.active) { this._startListening(); } else { this._stopListening(); } } - private _stopListening() { - if (this.recognition) { - this.recognition.stop(); + private async _startListening() { + if (!this._audioRecorder) { + this._audioRecorder = new AudioRecorder((audio) => { + if (this._audioBuffer) { + this._audioBuffer.push(audio); + } else { + this._sendAudioChunk(audio); + } + }); + } + this._audioBuffer = []; + const userMessage: Message = { + who: "user", + text: "…", + }; + this._audioRecorder.start().then(() => { + this._addMessage(userMessage); + this.requestUpdate("_audioRecorder"); + }); + const hassMessage: Message = { + who: "hass", + text: "…", + }; + // To make sure the answer is placed at the right user text, we add it before we process it + try { + const unsub = await runAssistPipeline( + this.hass, + (event) => { + if (event.type === "run-start") { + this._stt_binary_handler_id = + event.data.runner_data.stt_binary_handler_id; + } + + // When we start STT stage, the WS has a binary handler + if (event.type === "stt-start" && this._audioBuffer) { + // Send the buffer over the WS to the STT engine. 
+ for (const buffer of this._audioBuffer) { + this._sendAudioChunk(buffer); + } + this._audioBuffer = undefined; + } + + // Stop recording if the server is done with STT stage + if (event.type === "stt-end") { + this._stt_binary_handler_id = undefined; + this._stopListening(); + userMessage.text = event.data.stt_output.text; + this.requestUpdate("_conversation"); + // To make sure the answer is placed at the right user text, we add it before we process it + this._addMessage(hassMessage); + } + + if (event.type === "intent-end") { + this._conversationId = event.data.intent_output.conversation_id; + const plain = event.data.intent_output.response.speech?.plain; + if (plain) { + hassMessage.text = plain.speech; + } + this.requestUpdate("_conversation"); + } + + if (event.type === "tts-end") { + const url = event.data.tts_output.url; + const audio = new Audio(url); + audio.play(); + } + + if (event.type === "run-end") { + unsub(); + } + + if (event.type === "error") { + this._stt_binary_handler_id = undefined; + if (userMessage.text === "…") { + userMessage.text = event.data.message; + userMessage.error = true; + } else { + hassMessage.text = event.data.message; + hassMessage.error = true; + } + this._stopListening(); + this.requestUpdate("_conversation"); + unsub(); + } + }, + { + start_stage: "stt", + end_stage: this._pipeline?.tts_engine ? "tts" : "intent", + input: { sample_rate: this._audioRecorder.sampleRate! }, + pipeline: this._pipelineId, + conversation_id: this._conversationId, + } + ); + } catch (err: any) { + await showAlertDialog(this, { + title: "Error starting pipeline", + text: err.message || err, + }); + this._stopListening(); } } - private _startListening() { - if (!this.recognition) { - this._initRecognition(); + private _stopListening() { + this._audioRecorder?.stop(); + this.requestUpdate("_audioRecorder"); + // We're currently STTing, so finish audio + if (this._stt_binary_handler_id) { + if (this._audioBuffer) { + for (const chunk of this._audioBuffer) { + this._sendAudioChunk(chunk); + } + } + // Send empty message to indicate we're done streaming. + this._sendAudioChunk(new Int16Array()); } + this._audioBuffer = undefined; + } - if (this.results) { + private _sendAudioChunk(chunk: Int16Array) { + this.hass.connection.socket!.binaryType = "arraybuffer"; + + // eslint-disable-next-line eqeqeq + if (this._stt_binary_handler_id == undefined) { return; } + // Turn into 8 bit so we can prefix our handler ID. 
+ const data = new Uint8Array(1 + chunk.length * 2); + data[0] = this._stt_binary_handler_id; + data.set(new Uint8Array(chunk.buffer), 1); - this.results = { - transcript: "", - final: false, - }; - this.recognition!.start(); + this.hass.connection.socket!.send(data); } private _scrollMessagesBottom() { - this._scrollContainer.scrollTo(0, 99999); + const scrollContainer = this._scrollContainer; + if (!scrollContainer) { + return; + } + scrollContainer.scrollTo(0, 99999); } private _computeMessageClasses(message: Message) { @@ -512,10 +566,6 @@ export class HaVoiceCommandDialog extends LitElement { margin-right: 0; } - .interimTranscript { - color: var(--secondary-text-color); - } - .bouncer { width: 48px; height: 48px; diff --git a/src/panels/config/voice-assistants/assist-pref.ts b/src/panels/config/voice-assistants/assist-pref.ts index bd26b9ca03..7ea9a3032c 100644 --- a/src/panels/config/voice-assistants/assist-pref.ts +++ b/src/panels/config/voice-assistants/assist-pref.ts @@ -12,7 +12,7 @@ import "../../../components/ha-button"; import { createAssistPipeline, deleteAssistPipeline, - fetchAssistPipelines, + listAssistPipelines, updateAssistPipeline, AssistPipeline, setAssistPipelinePreferred, @@ -33,7 +33,7 @@ export class AssistPref extends LitElement { protected firstUpdated(changedProps: PropertyValues) { super.firstUpdated(changedProps); - fetchAssistPipelines(this.hass).then((pipelines) => { + listAssistPipelines(this.hass).then((pipelines) => { this._pipelines = pipelines.pipelines; this._preferred = pipelines.preferred_pipeline; }); diff --git a/src/panels/config/voice-assistants/debug/assist-pipeline-run-debug.ts b/src/panels/config/voice-assistants/debug/assist-pipeline-run-debug.ts index a597973bed..c019bd76ca 100644 --- a/src/panels/config/voice-assistants/debug/assist-pipeline-run-debug.ts +++ b/src/panels/config/voice-assistants/debug/assist-pipeline-run-debug.ts @@ -1,17 +1,17 @@ import { css, html, LitElement, TemplateResult } from "lit"; import { customElement, property, query, state } from "lit/decorators"; import { extractSearchParam } from "../../../../common/url/search-params"; +import "../../../../components/ha-assist-pipeline-picker"; import "../../../../components/ha-button"; import "../../../../components/ha-checkbox"; import type { HaCheckbox } from "../../../../components/ha-checkbox"; import "../../../../components/ha-formfield"; -import "../../../../components/ha-assist-pipeline-picker"; import "../../../../components/ha-textfield"; import type { HaTextField } from "../../../../components/ha-textfield"; import { PipelineRun, PipelineRunOptions, - runAssistPipeline, + runDebugAssistPipeline, } from "../../../../data/assist_pipeline"; import { showAlertDialog, @@ -20,6 +20,7 @@ import { import "../../../../layouts/hass-subpage"; import { haStyle } from "../../../../resources/styles"; import type { HomeAssistant } from "../../../../types"; +import { AudioRecorder } from "../../../../util/audio-recorder"; import { fileDownload } from "../../../../util/file_download"; import "./assist-render-pipeline-run"; @@ -81,7 +82,13 @@ export class AssistPipelineRunDebug extends LitElement { Run Text Pipeline - + Run Audio Pipeline ` @@ -173,21 +180,16 @@ export class AssistPipelineRunDebug extends LitElement { } private async _runAudioPipeline() { - // @ts-ignore-next-line - const context = new (window.AudioContext || window.webkitAudioContext)(); - let stream: MediaStream; - try { - stream = await navigator.mediaDevices.getUserMedia({ audio: true }); - } catch (err) { - 
return; - } + const audioRecorder = new AudioRecorder((data) => { + if (this._audioBuffer) { + this._audioBuffer.push(data); + } else { + this._sendAudioChunk(data); + } + }); - await context.audioWorklet.addModule( - new URL("./recorder.worklet.js", import.meta.url) - ); - - const source = context.createMediaStreamSource(stream); - const recorder = new AudioWorkletNode(context, "recorder.worklet"); + this._audioBuffer = []; + audioRecorder.start(); this.hass.connection.socket!.binaryType = "arraybuffer"; @@ -195,6 +197,7 @@ export class AssistPipelineRunDebug extends LitElement { let stopRecording: (() => void) | undefined = () => { stopRecording = undefined; + audioRecorder.close(); // We're currently STTing, so finish audio if (run?.stage === "stt" && run.stt!.done === false) { if (this._audioBuffer) { @@ -206,20 +209,6 @@ export class AssistPipelineRunDebug extends LitElement { this._sendAudioChunk(new Int16Array()); } this._audioBuffer = undefined; - stream.getTracks()[0].stop(); - context.close(); - }; - this._audioBuffer = []; - source.connect(recorder).connect(context.destination); - recorder.port.onmessage = (e) => { - if (!stopRecording) { - return; - } - if (this._audioBuffer) { - this._audioBuffer.push(e.data); - } else { - this._sendAudioChunk(e.data); - } }; await this._doRunPipeline( @@ -260,7 +249,7 @@ export class AssistPipelineRunDebug extends LitElement { start_stage: "stt", end_stage: "tts", input: { - sample_rate: context.sampleRate, + sample_rate: audioRecorder.sampleRate!, }, } ); @@ -273,7 +262,7 @@ export class AssistPipelineRunDebug extends LitElement { this._finished = false; let added = false; try { - await runAssistPipeline( + await runDebugAssistPipeline( this.hass, (updatedRun) => { if (added) { diff --git a/src/util/audio-recorder.ts b/src/util/audio-recorder.ts new file mode 100644 index 0000000000..4a4c2c0088 --- /dev/null +++ b/src/util/audio-recorder.ts @@ -0,0 +1,88 @@ +export class AudioRecorder { + private _active = false; + + private _callback: (data: Int16Array) => void; + + private _context: AudioContext | undefined; + + private _stream: MediaStream | undefined; + + constructor(callback: (data: Int16Array) => void) { + this._callback = callback; + } + + public get active() { + return this._active; + } + + public get sampleRate() { + return this._context?.sampleRate; + } + + public static get isSupported() { + return ( + window.isSecureContext && + // @ts-ignore-next-line + (window.AudioContext || window.webkitAudioContext) + ); + } + + public async start() { + this._active = true; + + if (!this._context || !this._stream) { + await this._createContext(); + } else { + this._context.resume(); + this._stream.getTracks()[0].enabled = true; + } + + if (!this._context || !this._stream) { + this._active = false; + return; + } + + const source = this._context.createMediaStreamSource(this._stream); + const recorder = new AudioWorkletNode(this._context, "recorder.worklet"); + + source.connect(recorder).connect(this._context.destination); + recorder.port.onmessage = (e) => { + if (!this._active) { + return; + } + this._callback(e.data); + }; + } + + public async stop() { + this._active = false; + if (this._stream) { + this._stream.getTracks()[0].enabled = false; + } + await this._context?.suspend(); + } + + public close() { + this._active = false; + this._stream?.getTracks()[0].stop(); + this._context?.close(); + this._stream = undefined; + this._context = undefined; + } + + private async _createContext() { + try { + // @ts-ignore-next-line + this._context = 
new (window.AudioContext || window.webkitAudioContext)(); + this._stream = await navigator.mediaDevices.getUserMedia({ audio: true }); + } catch (err) { + // eslint-disable-next-line no-console + console.error(err); + return; + } + + await this._context.audioWorklet.addModule( + new URL("./recorder.worklet.js", import.meta.url) + ); + } +} diff --git a/src/panels/config/voice-assistants/debug/recorder.worklet.js b/src/util/recorder.worklet.js similarity index 100% rename from src/panels/config/voice-assistants/debug/recorder.worklet.js rename to src/util/recorder.worklet.js
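
Reviewer note on the binary audio protocol (illustration only, not part of the patch): both `_sendAudioChunk` implementations frame each WebSocket binary message as the 1-byte `stt_binary_handler_id` (taken from the `run-start` event's `runner_data`) followed by the raw 16-bit PCM samples, and a message containing only the header byte signals the end of the audio stream. A minimal sketch of that framing, using a hypothetical standalone helper named `frameAudioChunk` that does not exist in this PR:

    // Sketch only, mirroring the framing in _sendAudioChunk above.
    function frameAudioChunk(handlerId: number, chunk: Int16Array): Uint8Array {
      // 1 prefix byte + 2 bytes per 16-bit sample.
      const framed = new Uint8Array(1 + chunk.length * 2);
      framed[0] = handlerId;
      // Reinterpret the sample buffer as bytes (platform byte order, as in the PR).
      framed.set(new Uint8Array(chunk.buffer, chunk.byteOffset, chunk.byteLength), 1);
      return framed;
    }

    // Usage against an open connection (binaryType must be "arraybuffer"):
    //   socket.send(frameAudioChunk(handlerId, samples));          // stream audio
    //   socket.send(frameAudioChunk(handlerId, new Int16Array())); // end of stream

On the `AudioRecorder` lifecycle introduced here: `stop()` suspends the AudioContext and disables the microphone track so recording can resume cheaply on the next `start()`, while `close()` stops the track and closes the context for good; the dialog calls `close()` when it is dismissed.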