mirror of
https://github.com/home-assistant/core.git
synced 2025-11-06 17:40:11 +00:00
AI Task to store generated images in media dir (#152463)
This commit is contained in:
@@ -2,21 +2,31 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
from homeassistant.components.media_source import MediaSource, local_source
|
from homeassistant.components.media_source import MediaSource, local_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
|
from .const import DATA_MEDIA_SOURCE, DOMAIN, IMAGE_DIR
|
||||||
|
|
||||||
|
|
||||||
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
|
async def async_get_media_source(hass: HomeAssistant) -> MediaSource:
|
||||||
"""Set up local media source."""
|
"""Set up local media source."""
|
||||||
media_dir = hass.config.path(f"{DOMAIN}/{IMAGE_DIR}")
|
media_dirs = list(hass.config.media_dirs.values())
|
||||||
|
|
||||||
|
if not media_dirs:
|
||||||
|
raise HomeAssistantError(
|
||||||
|
"AI Task media source requires at least one media directory configured"
|
||||||
|
)
|
||||||
|
|
||||||
|
media_dir = Path(media_dirs[0]) / DOMAIN / IMAGE_DIR
|
||||||
|
|
||||||
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
|
hass.data[DATA_MEDIA_SOURCE] = source = local_source.LocalSource(
|
||||||
hass,
|
hass,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
"AI Generated Images",
|
"AI Generated Images",
|
||||||
{IMAGE_DIR: media_dir},
|
{IMAGE_DIR: str(media_dir)},
|
||||||
f"/{DOMAIN}",
|
f"/{DOMAIN}",
|
||||||
)
|
)
|
||||||
return source
|
return source
|
||||||
|
|||||||
@@ -26,7 +26,6 @@ EXCLUDE_FROM_BACKUP = [
|
|||||||
"tmp_backups/*.tar",
|
"tmp_backups/*.tar",
|
||||||
"OZW_Log.txt",
|
"OZW_Log.txt",
|
||||||
"tts/*",
|
"tts/*",
|
||||||
"ai_task/*",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
EXCLUDE_DATABASE_FROM_BACKUP = [
|
EXCLUDE_DATABASE_FROM_BACKUP = [
|
||||||
|
|||||||
@@ -1,7 +1,11 @@
|
|||||||
"""Test ai_task media source."""
|
"""Test ai_task media source."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components import media_source
|
from homeassistant.components import media_source
|
||||||
|
from homeassistant.components.ai_task.media_source import async_get_media_source
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
|
from homeassistant.exceptions import HomeAssistantError
|
||||||
|
|
||||||
|
|
||||||
async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None:
|
async def test_local_media_source(hass: HomeAssistant, init_components: None) -> None:
|
||||||
@@ -9,3 +13,26 @@ async def test_local_media_source(hass: HomeAssistant, init_components: None) ->
|
|||||||
item = await media_source.async_browse_media(hass, "media-source://")
|
item = await media_source.async_browse_media(hass, "media-source://")
|
||||||
|
|
||||||
assert any(c.title == "AI Generated Images" for c in item.children)
|
assert any(c.title == "AI Generated Images" for c in item.children)
|
||||||
|
|
||||||
|
source = await async_get_media_source(hass)
|
||||||
|
assert isinstance(source, media_source.local_source.LocalSource)
|
||||||
|
assert source.name == "AI Generated Images"
|
||||||
|
assert source.domain == "ai_task"
|
||||||
|
assert list(source.media_dirs) == ["image"]
|
||||||
|
# Depending on Docker, the default is one of the two paths
|
||||||
|
assert source.media_dirs["image"] in (
|
||||||
|
"/media/ai_task/image",
|
||||||
|
hass.config.path("media/ai_task/image"),
|
||||||
|
)
|
||||||
|
assert source.url_prefix == "/ai_task"
|
||||||
|
|
||||||
|
hass.config.media_dirs = {}
|
||||||
|
|
||||||
|
with pytest.raises(
|
||||||
|
HomeAssistantError,
|
||||||
|
match="AI Task media source requires at least one media directory configured",
|
||||||
|
):
|
||||||
|
await async_get_media_source(hass)
|
||||||
|
|
||||||
|
|
||||||
|
# The following is from media_source/__init__.py for reference
|
||||||
|
|||||||
Reference in New Issue
Block a user