Fix playing TTS and local media source over DLNA (#134903)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Fabio Natanael Kepler 2025-06-26 16:12:15 +01:00 committed by GitHub
parent 7b80c1c693
commit 1a92d4530e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 134 additions and 9 deletions

View File

@ -223,7 +223,7 @@ async def async_setup_auth(
# We first start with a string check to avoid parsing query params
# for every request.
elif (
request.method == "GET"
request.method in ["GET", "HEAD"]
and SIGN_QUERY_PARAM in request.query_string
and async_validate_signed_request(request)
):

View File

@ -288,8 +288,10 @@ class ImageView(HomeAssistantView):
"""Initialize an image view."""
self.component = component
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
"""Start a GET request."""
async def _authenticate_request(
self, request: web.Request, entity_id: str
) -> ImageEntity:
"""Authenticate request and return image entity."""
if (image_entity := self.component.get_entity(entity_id)) is None:
raise web.HTTPNotFound
@ -306,6 +308,31 @@ class ImageView(HomeAssistantView):
# Invalid sigAuth or image entity access token
raise web.HTTPForbidden
return image_entity
async def head(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
"""
image_entity = await self._authenticate_request(request, entity_id)
# Don't use `handle` as we don't care about the stream case, we only want
# to verify that the image exists.
try:
image = await _async_get_image(image_entity, IMAGE_TIMEOUT)
except (HomeAssistantError, ValueError) as ex:
raise web.HTTPInternalServerError from ex
return web.Response(
content_type=image.content_type,
headers={"Content-Length": str(len(image.content))},
)
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
"""Start a GET request."""
image_entity = await self._authenticate_request(request, entity_id)
return await self.handle(request, image_entity)
async def handle(
@ -317,7 +344,11 @@ class ImageView(HomeAssistantView):
except (HomeAssistantError, ValueError) as ex:
raise web.HTTPInternalServerError from ex
return web.Response(body=image.content, content_type=image.content_type)
return web.Response(
body=image.content,
content_type=image.content_type,
headers={"Content-Length": str(len(image.content))},
)
async def async_get_still_stream(

View File

@ -210,10 +210,8 @@ class LocalMediaView(http.HomeAssistantView):
self.hass = hass
self.source = source
async def get(
self, request: web.Request, source_dir_id: str, location: str
) -> web.FileResponse:
"""Start a GET request."""
async def _validate_media_path(self, source_dir_id: str, location: str) -> Path:
"""Validate media path and return it if valid."""
try:
raise_if_invalid_path(location)
except ValueError as err:
@ -233,6 +231,25 @@ class LocalMediaView(http.HomeAssistantView):
if not mime_type or mime_type.split("/")[0] not in MEDIA_MIME_TYPES:
raise web.HTTPNotFound
return media_path
async def head(
self, request: web.Request, source_dir_id: str, location: str
) -> None:
"""Handle a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
Check whether the location exists or not.
"""
await self._validate_media_path(source_dir_id, location)
async def get(
self, request: web.Request, source_dir_id: str, location: str
) -> web.FileResponse:
"""Handle a GET request."""
media_path = await self._validate_media_path(source_dir_id, location)
return web.FileResponse(media_path)

View File

@ -1185,6 +1185,21 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view."""
self.manager = manager
async def head(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
Check whether the token (file) exists and return its content type.
"""
stream = self.manager.token_to_stream.get(token)
if stream is None:
return web.Response(status=HTTPStatus.NOT_FOUND)
return web.Response(content_type=stream.content_type)
async def get(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a get request."""
stream = self.manager.token_to_stream.get(token)

View File

@ -305,16 +305,22 @@ async def test_auth_access_signed_path_with_refresh_token(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
)
req = await client.head(signed_path)
assert req.status == HTTPStatus.OK
req = await client.get(signed_path)
assert req.status == HTTPStatus.OK
data = await req.json()
assert data["user_id"] == refresh_token.user.id
# Use signature on other path
req = await client.head(f"/another_path?{signed_path.split('?')[1]}")
assert req.status == HTTPStatus.UNAUTHORIZED
req = await client.get(f"/another_path?{signed_path.split('?')[1]}")
assert req.status == HTTPStatus.UNAUTHORIZED
# We only allow GET
# We only allow GET and HEAD
req = await client.post(signed_path)
assert req.status == HTTPStatus.UNAUTHORIZED

View File

@ -174,10 +174,22 @@ async def test_fetch_image_authenticated(
"""Test fetching an image with an authenticated client."""
client = await hass_client()
# Using HEAD
resp = await client.head("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
assert resp.content_type == "image/jpeg"
assert resp.content_length == 4
resp = await client.head("/api/image_proxy/image.unknown")
assert resp.status == HTTPStatus.NOT_FOUND
# Using GET
resp = await client.get("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
body = await resp.read()
assert body == b"Test"
assert resp.content_type == "image/jpeg"
assert resp.content_length == 4
resp = await client.get("/api/image_proxy/image.unknown")
assert resp.status == HTTPStatus.NOT_FOUND
@ -260,10 +272,19 @@ async def test_fetch_image_url_success(
client = await hass_client()
# Using HEAD
resp = await client.head("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
assert resp.content_type == "image/png"
assert resp.content_length == 4
# Using GET
resp = await client.get("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
body = await resp.read()
assert body == b"Test"
assert resp.content_type == "image/png"
assert resp.content_length == 4
@respx.mock

View File

@ -105,6 +105,9 @@ async def test_media_view(
client = await hass_client()
# Protects against non-existent files
resp = await client.head("/media/local/invalid.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/local/invalid.txt")
assert resp.status == HTTPStatus.NOT_FOUND
@ -112,14 +115,23 @@ async def test_media_view(
assert resp.status == HTTPStatus.NOT_FOUND
# Protects against non-media files
resp = await client.head("/media/local/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/local/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
# Protects against unknown local media sources
resp = await client.head("/media/unknown_source/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/unknown_source/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
# Fetch available media
resp = await client.head("/media/local/test.mp3")
assert resp.status == HTTPStatus.OK
resp = await client.get("/media/local/test.mp3")
assert resp.status == HTTPStatus.OK

View File

@ -916,6 +916,29 @@ async def test_web_view_wrong_file(
assert req.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
("setup", "expected_url_suffix"),
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],
indirect=["setup"],
)
async def test_web_view_wrong_file_with_head_request(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
expected_url_suffix: str,
) -> None:
"""Set up a TTS platform and receive wrong file from web."""
client = await hass_client()
url = (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_-_{expected_url_suffix}.mp3"
)
req = await client.head(url)
assert req.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
("setup", "expected_url_suffix"),
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],