Compare commits

...

1 Commits

Author SHA1 Message Date
Paulus Schoutsen
05ec051bf9 Standardizing snapshotting camera 2025-09-13 23:50:27 -04:00
4 changed files with 114 additions and 72 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import annotations
from contextlib import AsyncExitStack, asynccontextmanager
from dataclasses import dataclass
from datetime import datetime, timedelta
import io
@@ -14,7 +15,7 @@ import voluptuous as vol
from homeassistant.components import camera, conversation, media_source
from homeassistant.components.http.auth import async_sign_path
from homeassistant.core import HomeAssistant, ServiceResponse, callback
from homeassistant.core import HomeAssistant, ServiceResponse
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import llm
from homeassistant.helpers.chat_session import ChatSession, async_get_chat_session
@@ -42,45 +43,21 @@ def _save_camera_snapshot(image: camera.Image) -> Path:
return Path(temp_file.name)
@asynccontextmanager
async def _resolve_attachments(
hass: HomeAssistant,
session: ChatSession,
attachments: list[dict] | None = None,
) -> list[conversation.Attachment]:
"""Resolve attachments for a task."""
resolved_attachments: list[conversation.Attachment] = []
created_files: list[Path] = []
async with AsyncExitStack() as stack:
resolved_attachments: list[conversation.Attachment] = []
for attachment in attachments or []:
media_content_id = attachment["media_content_id"]
# Special case for camera media sources
if media_content_id.startswith("media-source://camera/"):
# Extract entity_id from the media content ID
entity_id = media_content_id.removeprefix("media-source://camera/")
# Get snapshot from camera
image = await camera.async_get_image(hass, entity_id)
temp_filename = await hass.async_add_executor_job(
_save_camera_snapshot, image
for attachment in attachments or []:
media_content_id = attachment["media_content_id"]
media = await stack.enter_async_context(
media_source.async_resolve_with_path(hass, media_content_id, None)
)
created_files.append(temp_filename)
resolved_attachments.append(
conversation.Attachment(
media_content_id=media_content_id,
mime_type=image.content_type,
path=temp_filename,
)
)
else:
# Handle regular media sources
media = await media_source.async_resolve_media(hass, media_content_id, None)
if media.path is None:
raise HomeAssistantError(
"Only local attachments are currently supported"
)
resolved_attachments.append(
conversation.Attachment(
media_content_id=media_content_id,
@@ -89,22 +66,7 @@ async def _resolve_attachments(
)
)
if not created_files:
return resolved_attachments
def cleanup_files() -> None:
"""Cleanup temporary files."""
for file in created_files:
file.unlink(missing_ok=True)
@callback
def cleanup_files_callback() -> None:
"""Cleanup temporary files."""
hass.async_add_executor_job(cleanup_files)
session.async_on_cleanup(cleanup_files_callback)
return resolved_attachments
yield resolved_attachments
async def async_generate_data(
@@ -142,18 +104,19 @@ async def async_generate_data(
)
with async_get_chat_session(hass) as session:
resolved_attachments = await _resolve_attachments(hass, session, attachments)
return await entity.internal_async_generate_data(
session,
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments or None,
llm_api=llm_api,
),
)
async with _resolve_attachments(
hass, session, attachments
) as resolved_attachments:
return await entity.internal_async_generate_data(
session,
GenDataTask(
name=task_name,
instructions=instructions,
structure=structure,
attachments=resolved_attachments or None,
llm_api=llm_api,
),
)
async def async_generate_image(
@@ -189,16 +152,17 @@ async def async_generate_image(
)
with async_get_chat_session(hass) as session:
resolved_attachments = await _resolve_attachments(hass, session, attachments)
task_result = await entity.internal_async_generate_image(
session,
GenImageTask(
name=task_name,
instructions=instructions,
attachments=resolved_attachments or None,
),
)
async with _resolve_attachments(
hass, session, attachments
) as resolved_attachments:
task_result = await entity.internal_async_generate_image(
session,
GenImageTask(
name=task_name,
instructions=instructions,
attachments=resolved_attachments or None,
),
)
service_result = task_result.as_dict()
image_data = service_result.pop("image_data")

View File

@@ -3,6 +3,10 @@
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
import mimetypes
from pathlib import Path
import tempfile
from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.components.media_source import (
@@ -17,7 +21,7 @@ from homeassistant.const import ATTR_FRIENDLY_NAME
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from . import Camera, _async_stream_endpoint_url
from . import Camera, Image, _async_stream_endpoint_url, async_get_image
from .const import DATA_COMPONENT, DOMAIN, StreamType
@@ -84,6 +88,30 @@ class CameraMediaSource(MediaSource):
return PlayMedia(url, FORMAT_CONTENT_TYPE[HLS_PROVIDER])
@asynccontextmanager
async def async_resolve_with_path(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve to playable item with path."""
media = await self.async_resolve_media(item)
entity_id = item.identifier
image = await async_get_image(self.hass, entity_id)
media.path = await self.hass.async_add_executor_job(
self._save_camera_snapshot, image
)
yield media
await self.hass.async_add_executor_job(media.path.unlink)
def _save_camera_snapshot(self, image: Image) -> Path:
"""Save camera snapshot to temp file."""
with tempfile.NamedTemporaryFile(
mode="wb",
suffix=mimetypes.guess_extension(image.content_type, False),
delete=False,
) as temp_file:
temp_file.write(image.content)
return Path(temp_file.name)
async def async_browse_media(
self,
item: MediaSourceItem,

View File

@@ -3,6 +3,7 @@
from __future__ import annotations
from collections.abc import Callable
from contextlib import asynccontextmanager
from typing import Any, Protocol
import voluptuous as vol
@@ -197,6 +198,30 @@ async def async_resolve_media(
return await item.async_resolve()
@asynccontextmanager
async def async_resolve_with_path(
hass: HomeAssistant, media_content_id: str, target_media_player: str | None
) -> PlayMedia:
"""Get info to play media."""
if DOMAIN not in hass.data:
raise Unresolvable("Media Source not loaded")
try:
item = _get_media_item(hass, media_content_id, target_media_player)
except ValueError as err:
raise Unresolvable(
translation_domain=DOMAIN,
translation_key="resolve_media_failed",
translation_placeholders={
"media_content_id": str(media_content_id),
"error": str(err),
},
) from err
async with item.async_resolve_with_path() as media:
yield media
@websocket_api.websocket_command(
{
vol.Required("type"): "media_source/browse_media",

View File

@@ -2,13 +2,15 @@
from __future__ import annotations
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType
from homeassistant.core import HomeAssistant, callback
from .const import MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
from .const import DOMAIN, MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
from .error import Unresolvable
if TYPE_CHECKING:
from pathlib import Path
@@ -103,6 +105,12 @@ class MediaSourceItem:
assert self.domain is not None
return self.hass.data[MEDIA_SOURCE_DATA][self.domain]
@asynccontextmanager
async def async_resolve_with_path(self) -> PlayMedia:
"""Resolve to playable item with path."""
async with self.async_media_source().async_resolve_with_path(self) as media:
yield media
@classmethod
def from_uri(
cls, hass: HomeAssistant, uri: str, target_media_player: str | None
@@ -132,6 +140,23 @@ class MediaSource:
"""Resolve a media item to a playable item."""
raise NotImplementedError
@asynccontextmanager
async def async_resolve_with_path(self, item: MediaSourceItem) -> PlayMedia:
"""Resolve to playable item with path."""
item = await self.async_resolve_media(item)
if item.path is None:
raise Unresolvable(
translation_domain=DOMAIN,
# TODO translations
translation_key="resolve_media_path_failed",
translation_placeholders={
"media_content_id": item.media_source_id,
},
)
yield item
async def async_browse_media(self, item: MediaSourceItem) -> BrowseMediaSource:
"""Browse media."""
raise NotImplementedError