Some tweaks for media source (#64641)

This commit is contained in:
Paulus Schoutsen 2022-01-21 11:26:06 -08:00 committed by GitHub
parent e0e6853968
commit c72c39e9a1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 85 additions and 29 deletions

View File

@ -1271,3 +1271,7 @@ class BrowseMedia:
proposed_class = self.children[0].media_class proposed_class = self.children[0].media_class
if all(child.media_class == proposed_class for child in self.children): if all(child.media_class == proposed_class for child in self.children):
self.children_media_class = proposed_class self.children_media_class = proposed_class
def __repr__(self):
"""Return representation of browse media."""
return f"<BrowseMedia {self.title} ({self.media_class})>"

View File

@ -1,6 +1,7 @@
"""The media_source integration.""" """The media_source integration."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from urllib.parse import quote from urllib.parse import quote
@ -9,8 +10,11 @@ import voluptuous as vol
from homeassistant.components import frontend, websocket_api from homeassistant.components import frontend, websocket_api
from homeassistant.components.http.auth import async_sign_path from homeassistant.components.http.auth import async_sign_path
from homeassistant.components.media_player.const import ATTR_MEDIA_CONTENT_ID from homeassistant.components.media_player import (
from homeassistant.components.media_player.errors import BrowseError ATTR_MEDIA_CONTENT_ID,
BrowseError,
BrowseMedia,
)
from homeassistant.components.websocket_api import ActiveConnection from homeassistant.components.websocket_api import ActiveConnection
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.integration_platform import ( from homeassistant.helpers.integration_platform import (
@ -19,12 +23,26 @@ from homeassistant.helpers.integration_platform import (
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass from homeassistant.loader import bind_hass
from . import local_source, models from . import local_source
from .const import DOMAIN, URI_SCHEME, URI_SCHEME_REGEX from .const import DOMAIN, URI_SCHEME, URI_SCHEME_REGEX
from .error import Unresolvable from .error import MediaSourceError, Unresolvable
from .models import BrowseMediaSource, MediaSourceItem, PlayMedia
DEFAULT_EXPIRY_TIME = 3600 * 24 DEFAULT_EXPIRY_TIME = 3600 * 24
__all__ = [
"DOMAIN",
"is_media_source_id",
"generate_media_source_id",
"async_browse_media",
"async_resolve_media",
"BrowseMediaSource",
"PlayMedia",
"MediaSourceItem",
"Unresolvable",
"MediaSourceError",
]
def is_media_source_id(media_content_id: str) -> bool: def is_media_source_id(media_content_id: str) -> bool:
"""Test if identifier is a media source.""" """Test if identifier is a media source."""
@ -64,29 +82,43 @@ async def _process_media_source_platform(
@callback @callback
def _get_media_item( def _get_media_item(
hass: HomeAssistant, media_content_id: str | None hass: HomeAssistant, media_content_id: str | None
) -> models.MediaSourceItem: ) -> MediaSourceItem:
"""Return media item.""" """Return media item."""
if media_content_id: if media_content_id:
return models.MediaSourceItem.from_uri(hass, media_content_id) return MediaSourceItem.from_uri(hass, media_content_id)
# We default to our own domain if its only one registered # We default to our own domain if its only one registered
domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN domain = None if len(hass.data[DOMAIN]) > 1 else DOMAIN
return models.MediaSourceItem(hass, domain, "") return MediaSourceItem(hass, domain, "")
@bind_hass @bind_hass
async def async_browse_media( async def async_browse_media(
hass: HomeAssistant, media_content_id: str hass: HomeAssistant,
) -> models.BrowseMediaSource: media_content_id: str,
*,
content_filter: Callable[[BrowseMedia], bool] | None = None,
) -> BrowseMediaSource:
"""Return media player browse media results.""" """Return media player browse media results."""
return await _get_media_item(hass, media_content_id).async_browse() if DOMAIN not in hass.data:
raise BrowseError("Media Source not loaded")
item = await _get_media_item(hass, media_content_id).async_browse()
if content_filter is None or item.children is None:
return item
item.children = [
child for child in item.children if child.can_expand or content_filter(child)
]
return item
@bind_hass @bind_hass
async def async_resolve_media( async def async_resolve_media(hass: HomeAssistant, media_content_id: str) -> PlayMedia:
hass: HomeAssistant, media_content_id: str
) -> models.PlayMedia:
"""Get info to play media.""" """Get info to play media."""
if DOMAIN not in hass.data:
raise Unresolvable("Media Source not loaded")
return await _get_media_item(hass, media_content_id).async_resolve() return await _get_media_item(hass, media_content_id).async_resolve()

View File

@ -1,22 +1,22 @@
"""Test Media Source initialization.""" """Test Media Source initialization."""
from unittest.mock import patch from unittest.mock import Mock, patch
from urllib.parse import quote from urllib.parse import quote
import pytest import pytest
from homeassistant.components import media_source from homeassistant.components import media_source
from homeassistant.components.media_player.const import MEDIA_CLASS_DIRECTORY from homeassistant.components.media_player import MEDIA_CLASS_DIRECTORY, BrowseError
from homeassistant.components.media_player.errors import BrowseError
from homeassistant.components.media_source import const from homeassistant.components.media_source import const
from homeassistant.components.media_source.error import Unresolvable
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
async def test_is_media_source_id(): async def test_is_media_source_id():
"""Test media source validation.""" """Test media source validation."""
assert media_source.is_media_source_id(const.URI_SCHEME) assert media_source.is_media_source_id(media_source.URI_SCHEME)
assert media_source.is_media_source_id(f"{const.URI_SCHEME}domain") assert media_source.is_media_source_id(f"{media_source.URI_SCHEME}domain")
assert media_source.is_media_source_id(f"{const.URI_SCHEME}domain/identifier") assert media_source.is_media_source_id(
f"{media_source.URI_SCHEME}domain/identifier"
)
assert not media_source.is_media_source_id("test") assert not media_source.is_media_source_id("test")
@ -39,7 +39,7 @@ async def test_generate_media_source_id():
async def test_async_browse_media(hass): async def test_async_browse_media(hass):
"""Test browse media.""" """Test browse media."""
assert await async_setup_component(hass, const.DOMAIN, {}) assert await async_setup_component(hass, media_source.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
# Test non-media ignored (/media has test.mp3 and not_media.txt) # Test non-media ignored (/media has test.mp3 and not_media.txt)
@ -48,6 +48,17 @@ async def test_async_browse_media(hass):
assert media.title == "media" assert media.title == "media"
assert len(media.children) == 2 assert len(media.children) == 2
# Test content filter
media = await media_source.async_browse_media(
hass,
"",
content_filter=lambda item: item.media_content_type.startswith("video/"),
)
assert isinstance(media, media_source.models.BrowseMediaSource)
assert media.title == "media"
assert len(media.children) == 1, media.children
media.children[0].title = "Epic Sax Guy 10 Hours"
# Test invalid media content # Test invalid media content
with pytest.raises(ValueError): with pytest.raises(ValueError):
await media_source.async_browse_media(hass, "invalid") await media_source.async_browse_media(hass, "invalid")
@ -61,35 +72,35 @@ async def test_async_browse_media(hass):
async def test_async_resolve_media(hass): async def test_async_resolve_media(hass):
"""Test browse media.""" """Test browse media."""
assert await async_setup_component(hass, const.DOMAIN, {}) assert await async_setup_component(hass, media_source.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
media = await media_source.async_resolve_media( media = await media_source.async_resolve_media(
hass, hass,
media_source.generate_media_source_id(const.DOMAIN, "local/test.mp3"), media_source.generate_media_source_id(media_source.DOMAIN, "local/test.mp3"),
) )
assert isinstance(media, media_source.models.PlayMedia) assert isinstance(media, media_source.models.PlayMedia)
async def test_async_unresolve_media(hass): async def test_async_unresolve_media(hass):
"""Test browse media.""" """Test browse media."""
assert await async_setup_component(hass, const.DOMAIN, {}) assert await async_setup_component(hass, media_source.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
# Test no media content # Test no media content
with pytest.raises(Unresolvable): with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(hass, "") await media_source.async_resolve_media(hass, "")
async def test_websocket_browse_media(hass, hass_ws_client): async def test_websocket_browse_media(hass, hass_ws_client):
"""Test browse media websocket.""" """Test browse media websocket."""
assert await async_setup_component(hass, const.DOMAIN, {}) assert await async_setup_component(hass, media_source.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
media = media_source.models.BrowseMediaSource( media = media_source.models.BrowseMediaSource(
domain=const.DOMAIN, domain=media_source.DOMAIN,
identifier="/media", identifier="/media",
title="Local Media", title="Local Media",
media_class=MEDIA_CLASS_DIRECTORY, media_class=MEDIA_CLASS_DIRECTORY,
@ -137,7 +148,7 @@ async def test_websocket_browse_media(hass, hass_ws_client):
@pytest.mark.parametrize("filename", ["test.mp3", "Epic Sax Guy 10 Hours.mp4"]) @pytest.mark.parametrize("filename", ["test.mp3", "Epic Sax Guy 10 Hours.mp4"])
async def test_websocket_resolve_media(hass, hass_ws_client, filename): async def test_websocket_resolve_media(hass, hass_ws_client, filename):
"""Test browse media websocket.""" """Test browse media websocket."""
assert await async_setup_component(hass, const.DOMAIN, {}) assert await async_setup_component(hass, media_source.DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
client = await hass_ws_client(hass) client = await hass_ws_client(hass)
@ -152,7 +163,7 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
{ {
"id": 1, "id": 1,
"type": "media_source/resolve_media", "type": "media_source/resolve_media",
"media_content_id": f"{const.URI_SCHEME}{const.DOMAIN}/local/{filename}", "media_content_id": f"{const.URI_SCHEME}{media_source.DOMAIN}/local/{filename}",
} }
) )
@ -180,3 +191,12 @@ async def test_websocket_resolve_media(hass, hass_ws_client, filename):
assert not msg["success"] assert not msg["success"]
assert msg["error"]["code"] == "resolve_media_failed" assert msg["error"]["code"] == "resolve_media_failed"
assert msg["error"]["message"] == "test" assert msg["error"]["message"] == "test"
async def test_browse_resolve_without_setup():
"""Test browse and resolve work without being setup."""
with pytest.raises(BrowseError):
await media_source.async_browse_media(Mock(data={}), None)
with pytest.raises(media_source.Unresolvable):
await media_source.async_resolve_media(Mock(data={}), None)