Use HassKey in media_source (#148011)

This commit is contained in:
epenet 2025-07-03 09:56:46 +02:00 committed by GitHub
parent 691681a78a
commit a656b6e26a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 12 deletions

View File

@ -30,6 +30,7 @@ from .const import (
DOMAIN,
MEDIA_CLASS_MAP,
MEDIA_MIME_TYPES,
MEDIA_SOURCE_DATA,
URI_SCHEME,
URI_SCHEME_REGEX,
)
@ -78,7 +79,7 @@ def generate_media_source_id(domain: str, identifier: str) -> str:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the media_source component."""
hass.data[DOMAIN] = {}
hass.data[MEDIA_SOURCE_DATA] = {}
websocket_api.async_register_command(hass, websocket_browse_media)
websocket_api.async_register_command(hass, websocket_resolve_media)
frontend.async_register_built_in_panel(
@ -97,7 +98,7 @@ async def _process_media_source_platform(
platform: MediaSourceProtocol,
) -> None:
"""Process a media source platform."""
hass.data[DOMAIN][domain] = await platform.async_get_media_source(hass)
hass.data[MEDIA_SOURCE_DATA][domain] = await platform.async_get_media_source(hass)
@callback
@ -109,10 +110,10 @@ def _get_media_item(
item = MediaSourceItem.from_uri(hass, media_content_id, target_media_player)
else:
# We default to our own domain if its only one registered
domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN
domain = None if len(hass.data[MEDIA_SOURCE_DATA]) > 1 else DOMAIN
return MediaSourceItem(hass, domain, "", target_media_player)
if item.domain is not None and item.domain not in hass.data[DOMAIN]:
if item.domain is not None and item.domain not in hass.data[MEDIA_SOURCE_DATA]:
raise UnknownMediaSource(
translation_domain=DOMAIN,
translation_key="unknown_media_source",

View File

@ -1,10 +1,18 @@
"""Constants for the media_source integration."""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from homeassistant.components.media_player import MediaClass
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from .models import MediaSource
DOMAIN = "media_source"
MEDIA_SOURCE_DATA: HassKey[dict[str, MediaSource]] = HassKey(DOMAIN)
MEDIA_MIME_TYPES = ("audio", "video", "image")
MEDIA_CLASS_MAP = {
"audio": MediaClass.MUSIC,

View File

@ -6,7 +6,7 @@ import logging
import mimetypes
from pathlib import Path
import shutil
from typing import Any
from typing import Any, cast
from aiohttp import web
from aiohttp.web_request import FileField
@ -18,7 +18,7 @@ from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.core import HomeAssistant, callback
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES, MEDIA_SOURCE_DATA
from .error import Unresolvable
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
@ -30,7 +30,7 @@ LOGGER = logging.getLogger(__name__)
def async_setup(hass: HomeAssistant) -> None:
"""Set up local media source."""
source = LocalSource(hass)
hass.data[DOMAIN][DOMAIN] = source
hass.data[MEDIA_SOURCE_DATA][DOMAIN] = source
hass.http.register_view(LocalMediaView(hass, source))
hass.http.register_view(UploadMediaView(hass, source))
websocket_api.async_register_command(hass, websocket_remove_media)
@ -352,7 +352,7 @@ async def websocket_remove_media(
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
return
source: LocalSource = hass.data[DOMAIN][DOMAIN]
source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][DOMAIN])
try:
source_dir_id, location = source.async_parse_identifier(item)

View File

@ -3,12 +3,12 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, cast
from typing import TYPE_CHECKING, Any
from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType
from homeassistant.core import HomeAssistant, callback
from .const import DOMAIN, URI_SCHEME, URI_SCHEME_REGEX
from .const import MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
@dataclass(slots=True)
@ -70,7 +70,7 @@ class MediaSourceItem:
can_play=False,
can_expand=True,
)
for source in self.hass.data[DOMAIN].values()
for source in self.hass.data[MEDIA_SOURCE_DATA].values()
),
key=lambda item: item.title,
)
@ -85,7 +85,9 @@ class MediaSourceItem:
@callback
def async_media_source(self) -> MediaSource:
"""Return media source that owns this item."""
return cast(MediaSource, self.hass.data[DOMAIN][self.domain])
if TYPE_CHECKING:
assert self.domain is not None
return self.hass.data[MEDIA_SOURCE_DATA][self.domain]
@classmethod
def from_uri(