mirror of
https://github.com/home-assistant/core.git
synced 2025-11-06 01:19:29 +00:00
Add attachment support to AI task (#148120)
This commit is contained in:
@@ -2,17 +2,30 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, fields
|
||||
from typing import Any
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import media_source
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
|
||||
from .const import DATA_COMPONENT, DATA_PREFERENCES, AITaskEntityFeature
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class PlayMediaWithId(media_source.PlayMedia):
|
||||
"""Play media with a media content ID."""
|
||||
|
||||
media_content_id: str
|
||||
"""Media source ID to play."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return media source ID as a string."""
|
||||
return f"<PlayMediaWithId {self.media_content_id}>"
|
||||
|
||||
|
||||
async def async_generate_data(
|
||||
hass: HomeAssistant,
|
||||
*,
|
||||
@@ -20,6 +33,7 @@ async def async_generate_data(
|
||||
entity_id: str | None = None,
|
||||
instructions: str,
|
||||
structure: vol.Schema | None = None,
|
||||
attachments: list[dict] | None = None,
|
||||
) -> GenDataTaskResult:
|
||||
"""Run a task in the AI Task integration."""
|
||||
if entity_id is None:
|
||||
@@ -37,11 +51,37 @@ async def async_generate_data(
|
||||
f"AI Task entity {entity_id} does not support generating data"
|
||||
)
|
||||
|
||||
# Resolve attachments
|
||||
resolved_attachments: list[PlayMediaWithId] | None = None
|
||||
|
||||
if attachments:
|
||||
if AITaskEntityFeature.SUPPORT_ATTACHMENTS not in entity.supported_features:
|
||||
raise HomeAssistantError(
|
||||
f"AI Task entity {entity_id} does not support attachments"
|
||||
)
|
||||
|
||||
resolved_attachments = []
|
||||
|
||||
for attachment in attachments:
|
||||
media = await media_source.async_resolve_media(
|
||||
hass, attachment["media_content_id"], None
|
||||
)
|
||||
resolved_attachments.append(
|
||||
PlayMediaWithId(
|
||||
**{
|
||||
field.name: getattr(media, field.name)
|
||||
for field in fields(media)
|
||||
},
|
||||
media_content_id=attachment["media_content_id"],
|
||||
)
|
||||
)
|
||||
|
||||
return await entity.internal_async_generate_data(
|
||||
GenDataTask(
|
||||
name=task_name,
|
||||
instructions=instructions,
|
||||
structure=structure,
|
||||
attachments=resolved_attachments,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -59,6 +99,9 @@ class GenDataTask:
|
||||
structure: vol.Schema | None = None
|
||||
"""Optional structure for the data to be generated."""
|
||||
|
||||
attachments: list[PlayMediaWithId] | None = None
|
||||
"""List of attachments to go along the instructions."""
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return task as a string."""
|
||||
return f"<GenDataTask {self.name}: {id(self)}>"
|
||||
|
||||
Reference in New Issue
Block a user