Use HassKey in media_source (#148011)

This commit is contained in:
epenet 2025-07-03 09:56:46 +02:00 committed by GitHub
parent 691681a78a
commit a656b6e26a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 23 additions and 12 deletions

View File

@ -30,6 +30,7 @@ from .const import (
DOMAIN, DOMAIN,
MEDIA_CLASS_MAP, MEDIA_CLASS_MAP,
MEDIA_MIME_TYPES, MEDIA_MIME_TYPES,
MEDIA_SOURCE_DATA,
URI_SCHEME, URI_SCHEME,
URI_SCHEME_REGEX, URI_SCHEME_REGEX,
) )
@ -78,7 +79,7 @@ def generate_media_source_id(domain: str, identifier: str) -> str:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the media_source component.""" """Set up the media_source component."""
hass.data[DOMAIN] = {} hass.data[MEDIA_SOURCE_DATA] = {}
websocket_api.async_register_command(hass, websocket_browse_media) websocket_api.async_register_command(hass, websocket_browse_media)
websocket_api.async_register_command(hass, websocket_resolve_media) websocket_api.async_register_command(hass, websocket_resolve_media)
frontend.async_register_built_in_panel( frontend.async_register_built_in_panel(
@ -97,7 +98,7 @@ async def _process_media_source_platform(
platform: MediaSourceProtocol, platform: MediaSourceProtocol,
) -> None: ) -> None:
"""Process a media source platform.""" """Process a media source platform."""
hass.data[DOMAIN][domain] = await platform.async_get_media_source(hass) hass.data[MEDIA_SOURCE_DATA][domain] = await platform.async_get_media_source(hass)
@callback @callback
@ -109,10 +110,10 @@ def _get_media_item(
item = MediaSourceItem.from_uri(hass, media_content_id, target_media_player) item = MediaSourceItem.from_uri(hass, media_content_id, target_media_player)
else: else:
# We default to our own domain if its only one registered # We default to our own domain if its only one registered
domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN domain = None if len(hass.data[MEDIA_SOURCE_DATA]) > 1 else DOMAIN
return MediaSourceItem(hass, domain, "", target_media_player) return MediaSourceItem(hass, domain, "", target_media_player)
if item.domain is not None and item.domain not in hass.data[DOMAIN]: if item.domain is not None and item.domain not in hass.data[MEDIA_SOURCE_DATA]:
raise UnknownMediaSource( raise UnknownMediaSource(
translation_domain=DOMAIN, translation_domain=DOMAIN,
translation_key="unknown_media_source", translation_key="unknown_media_source",

View File

@ -1,10 +1,18 @@
"""Constants for the media_source integration.""" """Constants for the media_source integration."""
from __future__ import annotations
import re import re
from typing import TYPE_CHECKING
from homeassistant.components.media_player import MediaClass from homeassistant.components.media_player import MediaClass
from homeassistant.util.hass_dict import HassKey
if TYPE_CHECKING:
from .models import MediaSource
DOMAIN = "media_source" DOMAIN = "media_source"
MEDIA_SOURCE_DATA: HassKey[dict[str, MediaSource]] = HassKey(DOMAIN)
MEDIA_MIME_TYPES = ("audio", "video", "image") MEDIA_MIME_TYPES = ("audio", "video", "image")
MEDIA_CLASS_MAP = { MEDIA_CLASS_MAP = {
"audio": MediaClass.MUSIC, "audio": MediaClass.MUSIC,

View File

@ -6,7 +6,7 @@ import logging
import mimetypes import mimetypes
from pathlib import Path from pathlib import Path
import shutil import shutil
from typing import Any from typing import Any, cast
from aiohttp import web from aiohttp import web
from aiohttp.web_request import FileField from aiohttp.web_request import FileField
@ -18,7 +18,7 @@ from homeassistant.components.media_player import BrowseError, MediaClass
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path from homeassistant.util import raise_if_invalid_filename, raise_if_invalid_path
from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES from .const import DOMAIN, MEDIA_CLASS_MAP, MEDIA_MIME_TYPES, MEDIA_SOURCE_DATA
from .error import Unresolvable from .error import Unresolvable
from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia from .models import BrowseMediaSource, MediaSource, MediaSourceItem, PlayMedia
@ -30,7 +30,7 @@ LOGGER = logging.getLogger(__name__)
def async_setup(hass: HomeAssistant) -> None: def async_setup(hass: HomeAssistant) -> None:
"""Set up local media source.""" """Set up local media source."""
source = LocalSource(hass) source = LocalSource(hass)
hass.data[DOMAIN][DOMAIN] = source hass.data[MEDIA_SOURCE_DATA][DOMAIN] = source
hass.http.register_view(LocalMediaView(hass, source)) hass.http.register_view(LocalMediaView(hass, source))
hass.http.register_view(UploadMediaView(hass, source)) hass.http.register_view(UploadMediaView(hass, source))
websocket_api.async_register_command(hass, websocket_remove_media) websocket_api.async_register_command(hass, websocket_remove_media)
@ -352,7 +352,7 @@ async def websocket_remove_media(
connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err)) connection.send_error(msg["id"], websocket_api.ERR_INVALID_FORMAT, str(err))
return return
source: LocalSource = hass.data[DOMAIN][DOMAIN] source = cast(LocalSource, hass.data[MEDIA_SOURCE_DATA][DOMAIN])
try: try:
source_dir_id, location = source.async_parse_identifier(item) source_dir_id, location = source.async_parse_identifier(item)

View File

@ -3,12 +3,12 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, cast from typing import TYPE_CHECKING, Any
from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType from homeassistant.components.media_player import BrowseMedia, MediaClass, MediaType
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from .const import DOMAIN, URI_SCHEME, URI_SCHEME_REGEX from .const import MEDIA_SOURCE_DATA, URI_SCHEME, URI_SCHEME_REGEX
@dataclass(slots=True) @dataclass(slots=True)
@ -70,7 +70,7 @@ class MediaSourceItem:
can_play=False, can_play=False,
can_expand=True, can_expand=True,
) )
for source in self.hass.data[DOMAIN].values() for source in self.hass.data[MEDIA_SOURCE_DATA].values()
), ),
key=lambda item: item.title, key=lambda item: item.title,
) )
@ -85,7 +85,9 @@ class MediaSourceItem:
@callback @callback
def async_media_source(self) -> MediaSource: def async_media_source(self) -> MediaSource:
"""Return media source that owns this item.""" """Return media source that owns this item."""
return cast(MediaSource, self.hass.data[DOMAIN][self.domain]) if TYPE_CHECKING:
assert self.domain is not None
return self.hass.data[MEDIA_SOURCE_DATA][self.domain]
@classmethod @classmethod
def from_uri( def from_uri(