Fix playing TTS and local media source over DLNA (#134903)

Co-authored-by: Erik Montnemery <erik@montnemery.com>
This commit is contained in:
Fabio Natanael Kepler 2025-06-26 16:12:15 +01:00 committed by GitHub
parent 7b80c1c693
commit 1a92d4530e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 134 additions and 9 deletions

View File

@ -223,7 +223,7 @@ async def async_setup_auth(
# We first start with a string check to avoid parsing query params # We first start with a string check to avoid parsing query params
# for every request. # for every request.
elif ( elif (
request.method == "GET" request.method in ["GET", "HEAD"]
and SIGN_QUERY_PARAM in request.query_string and SIGN_QUERY_PARAM in request.query_string
and async_validate_signed_request(request) and async_validate_signed_request(request)
): ):

View File

@ -288,8 +288,10 @@ class ImageView(HomeAssistantView):
"""Initialize an image view.""" """Initialize an image view."""
self.component = component self.component = component
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse: async def _authenticate_request(
"""Start a GET request.""" self, request: web.Request, entity_id: str
) -> ImageEntity:
"""Authenticate request and return image entity."""
if (image_entity := self.component.get_entity(entity_id)) is None: if (image_entity := self.component.get_entity(entity_id)) is None:
raise web.HTTPNotFound raise web.HTTPNotFound
@ -306,6 +308,31 @@ class ImageView(HomeAssistantView):
# Invalid sigAuth or image entity access token # Invalid sigAuth or image entity access token
raise web.HTTPForbidden raise web.HTTPForbidden
return image_entity
async def head(self, request: web.Request, entity_id: str) -> web.Response:
"""Start a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
"""
image_entity = await self._authenticate_request(request, entity_id)
# Don't use `handle` as we don't care about the stream case, we only want
# to verify that the image exists.
try:
image = await _async_get_image(image_entity, IMAGE_TIMEOUT)
except (HomeAssistantError, ValueError) as ex:
raise web.HTTPInternalServerError from ex
return web.Response(
content_type=image.content_type,
headers={"Content-Length": str(len(image.content))},
)
async def get(self, request: web.Request, entity_id: str) -> web.StreamResponse:
"""Start a GET request."""
image_entity = await self._authenticate_request(request, entity_id)
return await self.handle(request, image_entity) return await self.handle(request, image_entity)
async def handle( async def handle(
@ -317,7 +344,11 @@ class ImageView(HomeAssistantView):
except (HomeAssistantError, ValueError) as ex: except (HomeAssistantError, ValueError) as ex:
raise web.HTTPInternalServerError from ex raise web.HTTPInternalServerError from ex
return web.Response(body=image.content, content_type=image.content_type) return web.Response(
body=image.content,
content_type=image.content_type,
headers={"Content-Length": str(len(image.content))},
)
async def async_get_still_stream( async def async_get_still_stream(

View File

@ -210,10 +210,8 @@ class LocalMediaView(http.HomeAssistantView):
self.hass = hass self.hass = hass
self.source = source self.source = source
async def get( async def _validate_media_path(self, source_dir_id: str, location: str) -> Path:
self, request: web.Request, source_dir_id: str, location: str """Validate media path and return it if valid."""
) -> web.FileResponse:
"""Start a GET request."""
try: try:
raise_if_invalid_path(location) raise_if_invalid_path(location)
except ValueError as err: except ValueError as err:
@ -233,6 +231,25 @@ class LocalMediaView(http.HomeAssistantView):
if not mime_type or mime_type.split("/")[0] not in MEDIA_MIME_TYPES: if not mime_type or mime_type.split("/")[0] not in MEDIA_MIME_TYPES:
raise web.HTTPNotFound raise web.HTTPNotFound
return media_path
async def head(
self, request: web.Request, source_dir_id: str, location: str
) -> None:
"""Handle a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
Check whether the location exists or not.
"""
await self._validate_media_path(source_dir_id, location)
async def get(
self, request: web.Request, source_dir_id: str, location: str
) -> web.FileResponse:
"""Handle a GET request."""
media_path = await self._validate_media_path(source_dir_id, location)
return web.FileResponse(media_path) return web.FileResponse(media_path)

View File

@ -1185,6 +1185,21 @@ class TextToSpeechView(HomeAssistantView):
"""Initialize a tts view.""" """Initialize a tts view."""
self.manager = manager self.manager = manager
async def head(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a HEAD request.
This is sent by some DLNA renderers, like Samsung ones, prior to sending
the GET request.
Check whether the token (file) exists and return its content type.
"""
stream = self.manager.token_to_stream.get(token)
if stream is None:
return web.Response(status=HTTPStatus.NOT_FOUND)
return web.Response(content_type=stream.content_type)
async def get(self, request: web.Request, token: str) -> web.StreamResponse: async def get(self, request: web.Request, token: str) -> web.StreamResponse:
"""Start a get request.""" """Start a get request."""
stream = self.manager.token_to_stream.get(token) stream = self.manager.token_to_stream.get(token)

View File

@ -305,16 +305,22 @@ async def test_auth_access_signed_path_with_refresh_token(
hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id hass, "/", timedelta(seconds=5), refresh_token_id=refresh_token.id
) )
req = await client.head(signed_path)
assert req.status == HTTPStatus.OK
req = await client.get(signed_path) req = await client.get(signed_path)
assert req.status == HTTPStatus.OK assert req.status == HTTPStatus.OK
data = await req.json() data = await req.json()
assert data["user_id"] == refresh_token.user.id assert data["user_id"] == refresh_token.user.id
# Use signature on other path # Use signature on other path
req = await client.head(f"/another_path?{signed_path.split('?')[1]}")
assert req.status == HTTPStatus.UNAUTHORIZED
req = await client.get(f"/another_path?{signed_path.split('?')[1]}") req = await client.get(f"/another_path?{signed_path.split('?')[1]}")
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED
# We only allow GET # We only allow GET and HEAD
req = await client.post(signed_path) req = await client.post(signed_path)
assert req.status == HTTPStatus.UNAUTHORIZED assert req.status == HTTPStatus.UNAUTHORIZED

View File

@ -174,10 +174,22 @@ async def test_fetch_image_authenticated(
"""Test fetching an image with an authenticated client.""" """Test fetching an image with an authenticated client."""
client = await hass_client() client = await hass_client()
# Using HEAD
resp = await client.head("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
assert resp.content_type == "image/jpeg"
assert resp.content_length == 4
resp = await client.head("/api/image_proxy/image.unknown")
assert resp.status == HTTPStatus.NOT_FOUND
# Using GET
resp = await client.get("/api/image_proxy/image.test") resp = await client.get("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
body = await resp.read() body = await resp.read()
assert body == b"Test" assert body == b"Test"
assert resp.content_type == "image/jpeg"
assert resp.content_length == 4
resp = await client.get("/api/image_proxy/image.unknown") resp = await client.get("/api/image_proxy/image.unknown")
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND
@ -260,10 +272,19 @@ async def test_fetch_image_url_success(
client = await hass_client() client = await hass_client()
# Using HEAD
resp = await client.head("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK
assert resp.content_type == "image/png"
assert resp.content_length == 4
# Using GET
resp = await client.get("/api/image_proxy/image.test") resp = await client.get("/api/image_proxy/image.test")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK
body = await resp.read() body = await resp.read()
assert body == b"Test" assert body == b"Test"
assert resp.content_type == "image/png"
assert resp.content_length == 4
@respx.mock @respx.mock

View File

@ -105,6 +105,9 @@ async def test_media_view(
client = await hass_client() client = await hass_client()
# Protects against non-existent files # Protects against non-existent files
resp = await client.head("/media/local/invalid.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/local/invalid.txt") resp = await client.get("/media/local/invalid.txt")
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND
@ -112,14 +115,23 @@ async def test_media_view(
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND
# Protects against non-media files # Protects against non-media files
resp = await client.head("/media/local/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/local/not_media.txt") resp = await client.get("/media/local/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND
# Protects against unknown local media sources # Protects against unknown local media sources
resp = await client.head("/media/unknown_source/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND
resp = await client.get("/media/unknown_source/not_media.txt") resp = await client.get("/media/unknown_source/not_media.txt")
assert resp.status == HTTPStatus.NOT_FOUND assert resp.status == HTTPStatus.NOT_FOUND
# Fetch available media # Fetch available media
resp = await client.head("/media/local/test.mp3")
assert resp.status == HTTPStatus.OK
resp = await client.get("/media/local/test.mp3") resp = await client.get("/media/local/test.mp3")
assert resp.status == HTTPStatus.OK assert resp.status == HTTPStatus.OK

View File

@ -916,6 +916,29 @@ async def test_web_view_wrong_file(
assert req.status == HTTPStatus.NOT_FOUND assert req.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize(
("setup", "expected_url_suffix"),
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],
indirect=["setup"],
)
async def test_web_view_wrong_file_with_head_request(
hass: HomeAssistant,
hass_client: ClientSessionGenerator,
setup: str,
expected_url_suffix: str,
) -> None:
"""Set up a TTS platform and receive wrong file from web."""
client = await hass_client()
url = (
"/api/tts_proxy/42f18378fd4393d18c8dd11d03fa9563c1e54491"
f"_en-us_-_{expected_url_suffix}.mp3"
)
req = await client.head(url)
assert req.status == HTTPStatus.NOT_FOUND
@pytest.mark.parametrize( @pytest.mark.parametrize(
("setup", "expected_url_suffix"), ("setup", "expected_url_suffix"),
[("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")], [("mock_setup", "test"), ("mock_config_entry_setup", "tts.test")],