diff --git a/src/data/ai_task.ts b/src/data/ai_task.ts index 451fced583..c89f6607fe 100644 --- a/src/data/ai_task.ts +++ b/src/data/ai_task.ts @@ -1,14 +1,23 @@ import type { HomeAssistant } from "../types"; +import type { Selector } from "./selector"; export interface AITaskPreferences { gen_data_entity_id: string | null; } -export interface GenDataTaskResult { +export interface GenDataTaskResult { conversation_id: string; - data: string; + data: T; } +export interface AITaskStructureField { + description?: string; + required?: boolean; + selector: Selector; +} + +export type AITaskStructure = Record; + export const fetchAITaskPreferences = (hass: HomeAssistant) => hass.callWS({ type: "ai_task/preferences/get", @@ -23,15 +32,16 @@ export const saveAITaskPreferences = ( ...preferences, }); -export const generateDataAITask = async ( +export const generateDataAITask = async ( hass: HomeAssistant, task: { task_name: string; entity_id?: string; instructions: string; + structure?: AITaskStructure; } -): Promise => { - const result = await hass.callService( +): Promise> => { + const result = await hass.callService>( "ai_task", "generate_data", task, diff --git a/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts b/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts index 2ba4069d0d..6500646071 100644 --- a/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts +++ b/src/panels/config/automation/automation-save-dialog/dialog-automation-save.ts @@ -30,6 +30,9 @@ import { generateDataAITask, } from "../../../../data/ai_task"; import { isComponentLoaded } from "../../../../common/config/is_component_loaded"; +import { computeStateDomain } from "../../../../common/entity/compute_state_domain"; +import { subscribeOne } from "../../../../common/util/subscribe-one"; +import { subscribeLabelRegistry } from "../../../../data/label_registry"; @customElement("ha-dialog-automation-save") class DialogAutomationSave extends LitElement implements HassDialog { @@ -75,7 +78,7 @@ class DialogAutomationSave extends LitElement implements HassDialog { this._entryUpdates.category ? "category" : "", this._entryUpdates.labels.length > 0 ? "labels" : "", this._entryUpdates.area ? "area" : "", - ]; + ].filter(Boolean); } public closeDialog() { @@ -346,17 +349,121 @@ class DialogAutomationSave extends LitElement implements HassDialog { } private async _suggest() { - const result = await generateDataAITask(this.hass, { - task_name: "frontend:automation:save", - instructions: `Suggest one name for the following Home Assistant automation. -Your answer should only contain the name, without any additional text or formatting. -The name should be relevant to the automation's purpose and should not exceed 50 characters. -The name should be short, descriptive, sentence case, and written in the language ${this.hass.language}. + const labels = await subscribeOne( + this.hass.connection, + subscribeLabelRegistry + ).then((labs) => + Object.fromEntries(labs.map((lab) => [lab.label_id, lab.name])) + ); + const automationInspiration: string[] = []; + for (const automation of Object.values(this.hass.states)) { + const entityEntry = this.hass.entities[automation.entity_id]; + if ( + computeStateDomain(automation) !== "automation" || + automation.attributes.restored || + !automation.attributes.friendly_name || + !entityEntry + ) { + continue; + } + + let inspiration = `- ${automation.attributes.friendly_name}`; + + if (entityEntry.labels.length) { + inspiration += ` (labels: ${entityEntry.labels + .map((label) => labels[label]) + .join(", ")})`; + } + + automationInspiration.push(inspiration); + } + + const result = await generateDataAITask<{ + name: string; + description: string | undefined; + labels: string[] | undefined; + }>(this.hass, { + task_name: "frontend:automation:save", + instructions: `Suggest in language "${this.hass.language}" a name, description, and labels for the following Home Assistant automation. + +The name should be relevant to the automation's purpose. +${ + automationInspiration.length + ? `The name should be in same style as existing automations. +Suggest labels if relevant to the automation's purpose. +Only suggest labels that are already used by existing automations.` + : `The name should be short, descriptive, sentence case, and written in the language ${this.hass.language}.` +} +If the automation contains 5+ steps, include a short description. + +For inspiration, here are existing automations: +${automationInspiration.join("\n")} + +The automation configuration is as follows: ${dump(this._params.config)} `, + structure: { + name: { + description: "The name of the automation", + required: true, + selector: { + text: {}, + }, + }, + description: { + description: "A short description of the automation", + required: false, + selector: { + text: {}, + }, + }, + labels: { + description: "Labels for the automation", + required: false, + selector: { + text: { + multiple: true, + }, + }, + }, + }, }); - this._newName = result.data.trim(); + this._newName = result.data.name; + if (result.data.description) { + this._newDescription = result.data.description; + if (!this._visibleOptionals.includes("description")) { + this._visibleOptionals = [...this._visibleOptionals, "description"]; + } + } + if (result.data.labels?.length) { + // We get back label names, convert them to IDs + const newLabels: Record = Object.fromEntries( + result.data.labels.map((name) => [name, undefined]) + ); + let toFind = result.data.labels.length; + for (const [labelId, labelName] of Object.entries(labels)) { + if (labelName in newLabels && newLabels[labelName] === undefined) { + newLabels[labelName] = labelId; + toFind--; + if (toFind === 0) { + break; + } + } + } + const foundLabels = Object.values(newLabels).filter( + (labelId) => labelId !== undefined + ); + if (foundLabels.length) { + this._entryUpdates = { + ...this._entryUpdates, + labels: foundLabels, + }; + if (!this._visibleOptionals.includes("labels")) { + this._visibleOptionals = [...this._visibleOptionals, "labels"]; + } + } + } } private async _save(): Promise {