Use assist_pipeline in voice command dialog (#16257)

* Remove speech recognition, add basic pipeline support

* Add basic voice support

* Cleanup

* Only use TTS if the pipeline supports it

* Update ha-voice-command-dialog.ts

* Fix types

* Handle stop during STT

* Revert "Fix types"

This reverts commit 741781e392048d2e29594e388386bebce70ff68e.

* Make active read-only

* Update ha-voice-command-dialog.ts

---------

Co-authored-by: Paul Bottein <paul.bottein@gmail.com>
Commit 1ded47d368 (parent 85a27e8bb1) by Bram Kragten, 2023-04-22 03:50:30 +02:00, committed by GitHub.
8 changed files with 340 additions and 206 deletions
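
For orientation before the per-file diffs: the dialog drops the browser SpeechRecognition API and instead drives a single assist pipeline run streamed over the Home Assistant WebSocket connection. The event types handled in this commit are summarized below as an illustrative TypeScript union drawn from the diffs that follow; it is a sketch, not the full upstream PipelineRunEvent type.

type HandledPipelineEventType =
  | "run-start" // data.runner_data.stt_binary_handler_id, used to stream audio
  | "stt-start" // the server is ready to receive binary audio chunks
  | "stt-end" // data.stt_output.text carries the transcription
  | "intent-end" // data.intent_output carries conversation_id and response.speech
  | "tts-end" // data.tts_output.url points at the synthesized audio
  | "run-end" // the run finished; time to unsubscribe
  | "error"; // data.message describes the failure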


@@ -1,7 +0,0 @@
export const SpeechRecognition =
window.SpeechRecognition || window.webkitSpeechRecognition;
export const SpeechGrammarList =
window.SpeechGrammarList || window.webkitSpeechGrammarList;
export const SpeechRecognitionEvent =
// @ts-expect-error
window.SpeechRecognitionEvent || window.webkitSpeechRecognitionEvent;


@@ -9,7 +9,7 @@ import {
import { customElement, property, state } from "lit/decorators";
import { fireEvent } from "../common/dom/fire_event";
import { stopPropagation } from "../common/dom/stop_propagation";
import { AssistPipeline, fetchAssistPipelines } from "../data/assist_pipeline";
import { AssistPipeline, listAssistPipelines } from "../data/assist_pipeline";
import { HomeAssistant } from "../types";
import "./ha-list-item";
import "./ha-select";
@@ -71,7 +71,7 @@ export class HaAssistPipelinePicker extends LitElement {
changedProperties: PropertyValueMap<any> | Map<PropertyKey, unknown>
): void {
super.firstUpdated(changedProperties);
fetchAssistPipelines(this.hass).then((pipelines) => {
listAssistPipelines(this.hass).then((pipelines) => {
this._pipelines = pipelines.pipelines;
this._preferredPipeline = pipelines.preferred_pipeline;
});


@@ -210,14 +210,15 @@ export const processEvent = (
return run;
};
export const runAssistPipeline = (
export const runDebugAssistPipeline = (
hass: HomeAssistant,
callback: (event: PipelineRun) => void,
callback: (run: PipelineRun) => void,
options: PipelineRunOptions
) => {
let run: PipelineRun | undefined;
const unsubProm = hass.connection.subscribeMessage<PipelineRunEvent>(
const unsubProm = runAssistPipeline(
hass,
(updateEvent) => {
run = processEvent(run, updateEvent, options);
@@ -229,15 +230,22 @@ export const runAssistPipeline = (
callback(run);
}
},
{
...options,
type: "assist_pipeline/run",
}
options
);
return unsubProm;
};
export const runAssistPipeline = (
hass: HomeAssistant,
callback: (event: PipelineRunEvent) => void,
options: PipelineRunOptions
) =>
hass.connection.subscribeMessage<PipelineRunEvent>(callback, {
...options,
type: "assist_pipeline/run",
});
export const listAssistPipelineRuns = (
hass: HomeAssistant,
pipeline_id: string
@@ -262,7 +270,7 @@ export const getAssistPipelineRun = (
pipeline_run_id,
});
export const fetchAssistPipelines = (hass: HomeAssistant) =>
export const listAssistPipelines = (hass: HomeAssistant) =>
hass.callWS<{
pipelines: AssistPipeline[];
preferred_pipeline: string | null;
@@ -270,6 +278,12 @@ export const fetchAssistPipelines = (hass: HomeAssistant) =>
type: "assist_pipeline/pipeline/list",
});
export const getAssistPipeline = (hass: HomeAssistant, pipeline_id?: string) =>
hass.callWS<AssistPipeline>({
type: "assist_pipeline/pipeline/get",
pipeline_id,
});
export const createAssistPipeline = (
hass: HomeAssistant,
pipeline: AssistPipelineMutableParams
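
A minimal usage sketch of the subscription helper above, for a text-only run (hedged: the calling context, handler body, and input text are illustrative; per the diff, runAssistPipeline resolves to an unsubscribe function and the options are spread into the "assist_pipeline/run" message):

const unsub = await runAssistPipeline(
  this.hass,
  (event) => {
    if (event.type === "intent-end") {
      // Illustrative: surface the spoken response, if any.
      console.log(event.data.intent_output.response.speech?.plain?.speech);
    }
    if (event.type === "run-end" || event.type === "error") {
      unsub();
    }
  },
  {
    start_stage: "intent",
    end_stage: "intent",
    input: { text: "turn on the kitchen lights" },
  }
);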


@@ -1,4 +1,3 @@
/* eslint-disable lit/prefer-static-styles */
import "@material/mwc-button/mwc-button";
import {
mdiClose,
@@ -11,28 +10,28 @@ import {
CSSResultGroup,
html,
LitElement,
PropertyValues,
nothing,
PropertyValues,
} from "lit";
import { customElement, property, query, state } from "lit/decorators";
import { classMap } from "lit/directives/class-map";
import { LocalStorage } from "../../common/decorators/local-storage";
import { fireEvent } from "../../common/dom/fire_event";
import { SpeechRecognition } from "../../common/dom/speech-recognition";
import "../../components/ha-dialog";
import type { HaDialog } from "../../components/ha-dialog";
import "../../components/ha-header-bar";
import "../../components/ha-icon-button";
import "../../components/ha-textfield";
import type { HaTextField } from "../../components/ha-textfield";
import {
AgentInfo,
getAgentInfo,
prepareConversation,
processConversationInput,
} from "../../data/conversation";
AssistPipeline,
getAssistPipeline,
runAssistPipeline,
} from "../../data/assist_pipeline";
import { AgentInfo, getAgentInfo } from "../../data/conversation";
import { haStyleDialog } from "../../resources/styles";
import type { HomeAssistant } from "../../types";
import { AudioRecorder } from "../../util/audio-recorder";
import { documentationUrl } from "../../util/documentation-url";
import { showAlertDialog } from "../generic/show-dialog-box";
interface Message {
who: string;
@@ -40,49 +39,53 @@ interface Message {
error?: boolean;
}
interface Results {
transcript: string;
final: boolean;
}
@customElement("ha-voice-command-dialog")
export class HaVoiceCommandDialog extends LitElement {
@property({ attribute: false }) public hass!: HomeAssistant;
@property() public results: Results | null = null;
@state() private _conversation: Message[] = [
{
who: "hass",
text: "",
},
];
@state() private _conversation?: Message[];
@state() private _opened = false;
@LocalStorage("AssistPipelineId", true, false) private _pipelineId?: string;
@state() private _pipeline?: AssistPipeline;
@state() private _agentInfo?: AgentInfo;
@state() private _showSendButton = false;
@query("#scroll-container") private _scrollContainer!: HaDialog;
@query("#scroll-container") private _scrollContainer!: HTMLDivElement;
@query("#message-input") private _messageInput!: HaTextField;
private recognition!: SpeechRecognition;
private _conversationId: string | null = null;
private _audioRecorder?: AudioRecorder;
private _audioBuffer?: Int16Array[];
private _stt_binary_handler_id?: number | null;
public async showDialog(): Promise<void> {
this._conversation = [
{
who: "hass",
text: this.hass.localize("ui.dialogs.voice_command.how_can_i_help"),
},
];
this._opened = true;
this._agentInfo = await getAgentInfo(this.hass);
await this.updateComplete;
this._scrollMessagesBottom();
}
public async closeDialog(): Promise<void> {
this._opened = false;
if (this.recognition) {
this.recognition.abort();
}
this._agentInfo = undefined;
this._conversation = undefined;
this._conversationId = null;
this._audioRecorder?.close();
this._audioRecorder = undefined;
fireEvent(this, "dialog-closed", { dialog: this.localName });
}
@@ -90,6 +93,7 @@ export class HaVoiceCommandDialog extends LitElement {
if (!this._opened) {
return nothing;
}
const supportsSTT = this._pipeline?.stt_engine && AudioRecorder.isSupported;
return html`
<ha-dialog
open
@@ -123,25 +127,13 @@
</div>
<div class="messages">
<div class="messages-container" id="scroll-container">
${this._conversation.map(
${this._conversation!.map(
(message) => html`
<div class=${this._computeMessageClasses(message)}>
${message.text}
</div>
`
)}
${this.results
? html`
<div class="message user">
<span
class=${classMap({
interimTranscript: !this.results.final,
})}
>${this.results.transcript}</span
>${!this.results.final ? "…" : ""}
</div>
`
: ""}
</div>
</div>
<div class="input" slot="primaryAction">
@@ -166,9 +158,9 @@
>
</ha-icon-button>
`
: SpeechRecognition
: supportsSTT
? html`
${this.results
${this._audioRecorder?.active
? html`
<div class="bouncer">
<div class="double-bounce1"></div>
@@ -205,15 +197,18 @@
`;
}
protected firstUpdated(changedProps: PropertyValues) {
super.updated(changedProps);
this._conversation = [
{
who: "hass",
text: this.hass.localize("ui.dialogs.voice_command.how_can_i_help"),
},
];
prepareConversation(this.hass, this.hass.language);
protected willUpdate(changedProperties: PropertyValues): void {
if (!this.hasUpdated || changedProperties.has("_pipelineId")) {
this._getPipeline();
}
}
private async _getPipeline() {
this._pipeline = await getAssistPipeline(this.hass, this._pipelineId);
this._agentInfo = await getAgentInfo(
this.hass,
this._pipeline.conversation_engine
);
}
protected updated(changedProps: PropertyValues) {
@@ -224,7 +219,7 @@
}
private _addMessage(message: Message) {
this._conversation = [...this._conversation, message];
this._conversation = [...this._conversation!, message];
}
private _handleKeyUp(ev: KeyboardEvent) {
@@ -253,75 +248,7 @@
}
}
private _initRecognition() {
this.recognition = new SpeechRecognition();
this.recognition.interimResults = true;
this.recognition.lang = this.hass.language;
this.recognition.continuous = false;
this.recognition.addEventListener("start", () => {
this.results = {
final: false,
transcript: "",
};
});
this.recognition.addEventListener("nomatch", () => {
this._addMessage({
who: "user",
text: `<${this.hass.localize(
"ui.dialogs.voice_command.did_not_understand"
)}>`,
error: true,
});
});
this.recognition.addEventListener("error", (event) => {
// eslint-disable-next-line
console.error("Error recognizing text", event);
this.recognition!.abort();
// @ts-ignore
if (event.error !== "aborted" && event.error !== "no-speech") {
const text =
this.results && this.results.transcript
? this.results.transcript
: `<${this.hass.localize(
"ui.dialogs.voice_command.did_not_hear"
)}>`;
this._addMessage({ who: "user", text, error: true });
}
this.results = null;
});
this.recognition.addEventListener("end", () => {
// Already handled by onerror
if (this.results == null) {
return;
}
const text = this.results.transcript;
this.results = null;
if (text) {
this._processText(text);
} else {
this._addMessage({
who: "user",
text: `<${this.hass.localize(
"ui.dialogs.voice_command.did_not_hear"
)}>`,
error: true,
});
}
});
this.recognition.addEventListener("result", (event) => {
const result = event.results[0];
this.results = {
transcript: result[0].transcript,
final: result.isFinal,
};
});
}
private async _processText(text: string) {
if (this.recognition) {
this.recognition.abort();
}
this._addMessage({ who: "user", text });
const message: Message = {
who: "hass",
@@ -330,21 +257,33 @@
// To make sure the answer is placed at the right user text, we add it before we process it
this._addMessage(message);
try {
const response = await processConversationInput(
const unsub = await runAssistPipeline(
this.hass,
text,
this._conversationId,
this.hass.language
(event) => {
if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
message.text = plain.speech;
}
this.requestUpdate("_conversation");
unsub();
}
if (event.type === "error") {
message.text = event.data.message;
message.error = true;
this.requestUpdate("_conversation");
unsub();
}
},
{
start_stage: "intent",
input: { text },
end_stage: "intent",
pipeline: this._pipelineId,
conversation_id: this._conversationId,
}
);
this._conversationId = response.conversation_id;
const plain = response.response.speech?.plain;
if (plain) {
message.text = plain.speech;
} else {
message.text = "<silence>";
}
this.requestUpdate("_conversation");
} catch {
message.text = this.hass.localize("ui.dialogs.voice_command.error");
message.error = true;
@@ -353,37 +292,152 @@
}
private _toggleListening() {
if (!this.results) {
if (!this._audioRecorder?.active) {
this._startListening();
} else {
this._stopListening();
}
}
private _stopListening() {
if (this.recognition) {
this.recognition.stop();
private async _startListening() {
if (!this._audioRecorder) {
this._audioRecorder = new AudioRecorder((audio) => {
if (this._audioBuffer) {
this._audioBuffer.push(audio);
} else {
this._sendAudioChunk(audio);
}
});
}
this._audioBuffer = [];
const userMessage: Message = {
who: "user",
text: "…",
};
this._audioRecorder.start().then(() => {
this._addMessage(userMessage);
this.requestUpdate("_audioRecorder");
});
const hassMessage: Message = {
who: "hass",
text: "…",
};
// To make sure the answer is placed at the right user text, we add it before we process it
try {
const unsub = await runAssistPipeline(
this.hass,
(event) => {
if (event.type === "run-start") {
this._stt_binary_handler_id =
event.data.runner_data.stt_binary_handler_id;
}
// When we start STT stage, the WS has a binary handler
if (event.type === "stt-start" && this._audioBuffer) {
// Send the buffer over the WS to the STT engine.
for (const buffer of this._audioBuffer) {
this._sendAudioChunk(buffer);
}
this._audioBuffer = undefined;
}
// Stop recording if the server is done with STT stage
if (event.type === "stt-end") {
this._stt_binary_handler_id = undefined;
this._stopListening();
userMessage.text = event.data.stt_output.text;
this.requestUpdate("_conversation");
// To make sure the answer is placed at the right user text, we add it before we process it
this._addMessage(hassMessage);
}
if (event.type === "intent-end") {
this._conversationId = event.data.intent_output.conversation_id;
const plain = event.data.intent_output.response.speech?.plain;
if (plain) {
hassMessage.text = plain.speech;
}
this.requestUpdate("_conversation");
}
if (event.type === "tts-end") {
const url = event.data.tts_output.url;
const audio = new Audio(url);
audio.play();
}
if (event.type === "run-end") {
unsub();
}
if (event.type === "error") {
this._stt_binary_handler_id = undefined;
if (userMessage.text === "…") {
userMessage.text = event.data.message;
userMessage.error = true;
} else {
hassMessage.text = event.data.message;
hassMessage.error = true;
}
this._stopListening();
this.requestUpdate("_conversation");
unsub();
}
},
{
start_stage: "stt",
end_stage: this._pipeline?.tts_engine ? "tts" : "intent",
input: { sample_rate: this._audioRecorder.sampleRate! },
pipeline: this._pipelineId,
conversation_id: this._conversationId,
}
);
} catch (err: any) {
await showAlertDialog(this, {
title: "Error starting pipeline",
text: err.message || err,
});
this._stopListening();
}
}
private _startListening() {
if (!this.recognition) {
this._initRecognition();
private _stopListening() {
this._audioRecorder?.stop();
this.requestUpdate("_audioRecorder");
// We're currently STTing, so finish audio
if (this._stt_binary_handler_id) {
if (this._audioBuffer) {
for (const chunk of this._audioBuffer) {
this._sendAudioChunk(chunk);
}
}
// Send empty message to indicate we're done streaming.
this._sendAudioChunk(new Int16Array());
}
this._audioBuffer = undefined;
}
if (this.results) {
private _sendAudioChunk(chunk: Int16Array) {
this.hass.connection.socket!.binaryType = "arraybuffer";
// eslint-disable-next-line eqeqeq
if (this._stt_binary_handler_id == undefined) {
return;
}
// Turn into 8 bit so we can prefix our handler ID.
const data = new Uint8Array(1 + chunk.length * 2);
data[0] = this._stt_binary_handler_id;
data.set(new Uint8Array(chunk.buffer), 1);
this.results = {
transcript: "",
final: false,
};
this.recognition!.start();
this.hass.connection.socket!.send(data);
}
private _scrollMessagesBottom() {
this._scrollContainer.scrollTo(0, 99999);
const scrollContainer = this._scrollContainer;
if (!scrollContainer) {
return;
}
scrollContainer.scrollTo(0, 99999);
}
private _computeMessageClasses(message: Message) {
@@ -512,10 +566,6 @@ export class HaVoiceCommandDialog extends LitElement {
margin-right: 0;
}
.interimTranscript {
color: var(--secondary-text-color);
}
.bouncer {
width: 48px;
height: 48px;
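
For clarity, the binary framing that _sendAudioChunk implements above, in isolation (a sketch; frameAudioChunk is a hypothetical name, not part of this commit): each 16-bit PCM chunk is reinterpreted as bytes and prefixed with the one-byte STT binary handler id announced in the run-start event, then sent over the existing WebSocket. An empty Int16Array framed the same way, a lone handler-id byte, tells the server the audio stream is done.

function frameAudioChunk(handlerId: number, samples: Int16Array): Uint8Array {
  // One byte for the handler id, plus two bytes per 16-bit sample.
  const data = new Uint8Array(1 + samples.length * 2);
  data[0] = handlerId;
  // Reinterpret the samples' backing buffer as bytes and copy them in.
  data.set(new Uint8Array(samples.buffer), 1);
  return data;
}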


@@ -12,7 +12,7 @@ import "../../../components/ha-button";
import {
createAssistPipeline,
deleteAssistPipeline,
fetchAssistPipelines,
listAssistPipelines,
updateAssistPipeline,
AssistPipeline,
setAssistPipelinePreferred,
@@ -33,7 +33,7 @@ export class AssistPref extends LitElement {
protected firstUpdated(changedProps: PropertyValues) {
super.firstUpdated(changedProps);
fetchAssistPipelines(this.hass).then((pipelines) => {
listAssistPipelines(this.hass).then((pipelines) => {
this._pipelines = pipelines.pipelines;
this._preferred = pipelines.preferred_pipeline;
});


@@ -1,17 +1,17 @@
import { css, html, LitElement, TemplateResult } from "lit";
import { customElement, property, query, state } from "lit/decorators";
import { extractSearchParam } from "../../../../common/url/search-params";
import "../../../../components/ha-assist-pipeline-picker";
import "../../../../components/ha-button";
import "../../../../components/ha-checkbox";
import type { HaCheckbox } from "../../../../components/ha-checkbox";
import "../../../../components/ha-formfield";
import "../../../../components/ha-assist-pipeline-picker";
import "../../../../components/ha-textfield";
import type { HaTextField } from "../../../../components/ha-textfield";
import {
PipelineRun,
PipelineRunOptions,
runAssistPipeline,
runDebugAssistPipeline,
} from "../../../../data/assist_pipeline";
import {
showAlertDialog,
@@ -20,6 +20,7 @@ import {
import "../../../../layouts/hass-subpage";
import { haStyle } from "../../../../resources/styles";
import type { HomeAssistant } from "../../../../types";
import { AudioRecorder } from "../../../../util/audio-recorder";
import { fileDownload } from "../../../../util/file_download";
import "./assist-render-pipeline-run";
@@ -81,7 +82,13 @@ export class AssistPipelineRunDebug extends LitElement {
<ha-button raised @click=${this._runTextPipeline}>
Run Text Pipeline
</ha-button>
<ha-button raised @click=${this._runAudioPipeline}>
<ha-button
raised
@click=${this._runAudioPipeline}
.disabled=${!window.isSecureContext ||
// @ts-ignore-next-line
!(window.AudioContext || window.webkitAudioContext)}
>
Run Audio Pipeline
</ha-button>
`
@@ -173,21 +180,16 @@
}
private async _runAudioPipeline() {
// @ts-ignore-next-line
const context = new (window.AudioContext || window.webkitAudioContext)();
let stream: MediaStream;
try {
stream = await navigator.mediaDevices.getUserMedia({ audio: true });
} catch (err) {
return;
}
const audioRecorder = new AudioRecorder((data) => {
if (this._audioBuffer) {
this._audioBuffer.push(data);
} else {
this._sendAudioChunk(data);
}
});
await context.audioWorklet.addModule(
new URL("./recorder.worklet.js", import.meta.url)
);
const source = context.createMediaStreamSource(stream);
const recorder = new AudioWorkletNode(context, "recorder.worklet");
this._audioBuffer = [];
audioRecorder.start();
this.hass.connection.socket!.binaryType = "arraybuffer";
@@ -195,6 +197,7 @@
let stopRecording: (() => void) | undefined = () => {
stopRecording = undefined;
audioRecorder.close();
// We're currently STTing, so finish audio
if (run?.stage === "stt" && run.stt!.done === false) {
if (this._audioBuffer) {
@@ -206,20 +209,6 @@
this._sendAudioChunk(new Int16Array());
}
this._audioBuffer = undefined;
stream.getTracks()[0].stop();
context.close();
};
this._audioBuffer = [];
source.connect(recorder).connect(context.destination);
recorder.port.onmessage = (e) => {
if (!stopRecording) {
return;
}
if (this._audioBuffer) {
this._audioBuffer.push(e.data);
} else {
this._sendAudioChunk(e.data);
}
};
await this._doRunPipeline(
@@ -260,7 +249,7 @@
start_stage: "stt",
end_stage: "tts",
input: {
sample_rate: context.sampleRate,
sample_rate: audioRecorder.sampleRate!,
},
}
);
@@ -273,7 +262,7 @@
this._finished = false;
let added = false;
try {
await runAssistPipeline(
await runDebugAssistPipeline(
this.hass,
(updatedRun) => {
if (added) {


@@ -0,0 +1,88 @@
export class AudioRecorder {
private _active = false;
private _callback: (data: Int16Array) => void;
private _context: AudioContext | undefined;
private _stream: MediaStream | undefined;
constructor(callback: (data: Int16Array) => void) {
this._callback = callback;
}
public get active() {
return this._active;
}
public get sampleRate() {
return this._context?.sampleRate;
}
public static get isSupported() {
return (
window.isSecureContext &&
// @ts-ignore-next-line
(window.AudioContext || window.webkitAudioContext)
);
}
public async start() {
this._active = true;
if (!this._context || !this._stream) {
await this._createContext();
} else {
this._context.resume();
this._stream.getTracks()[0].enabled = true;
}
if (!this._context || !this._stream) {
this._active = false;
return;
}
const source = this._context.createMediaStreamSource(this._stream);
const recorder = new AudioWorkletNode(this._context, "recorder.worklet");
source.connect(recorder).connect(this._context.destination);
recorder.port.onmessage = (e) => {
if (!this._active) {
return;
}
this._callback(e.data);
};
}
public async stop() {
this._active = false;
if (this._stream) {
this._stream.getTracks()[0].enabled = false;
}
await this._context?.suspend();
}
public close() {
this._active = false;
this._stream?.getTracks()[0].stop();
this._context?.close();
this._stream = undefined;
this._context = undefined;
}
private async _createContext() {
try {
// @ts-ignore-next-line
this._context = new (window.AudioContext || window.webkitAudioContext)();
this._stream = await navigator.mediaDevices.getUserMedia({ audio: true });
} catch (err) {
// eslint-disable-next-line no-console
console.error(err);
return;
}
await this._context.audioWorklet.addModule(
new URL("./recorder.worklet.js", import.meta.url)
);
}
}
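
A usage sketch for the new AudioRecorder helper (the callback body and control flow are illustrative; handleAudio is a hypothetical consumer):

const recorder = new AudioRecorder((chunk) => {
  // chunk is an Int16Array of PCM samples at recorder.sampleRate.
  handleAudio(chunk);
});

if (AudioRecorder.isSupported) {
  await recorder.start(); // request the microphone and begin streaming
  // ...
  await recorder.stop(); // pause: suspend the AudioContext, keep the stream
  recorder.close(); // release the microphone track and close the AudioContext
}

Note the design split: stop() suspends rather than tears down, so a later start() can resume without prompting for microphone permission again, while close() does the full teardown. This matches how the dialog toggles listening versus closing entirely.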