Refactor + strictly-type image component (#81808)

* image: refactor size validation to use partition

* image: give _generate_thumbnail types and use partition

* image: become strictly typed
This commit is contained in:
Aarni Koskela 2022-11-09 16:36:03 +02:00 committed by GitHub
parent ec316e94ed
commit 5a6f7e66cb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 65 additions and 18 deletions

View File

@ -151,6 +151,7 @@ homeassistant.components.http.*
homeassistant.components.huawei_lte.* homeassistant.components.huawei_lte.*
homeassistant.components.hyperion.* homeassistant.components.hyperion.*
homeassistant.components.ibeacon.* homeassistant.components.ibeacon.*
homeassistant.components.image.*
homeassistant.components.image_processing.* homeassistant.components.image_processing.*
homeassistant.components.input_button.* homeassistant.components.input_button.*
homeassistant.components.input_select.* homeassistant.components.input_select.*

View File

@ -6,6 +6,7 @@ import logging
import pathlib import pathlib
import secrets import secrets
import shutil import shutil
from typing import Any
from PIL import Image, ImageOps, UnidentifiedImageError from PIL import Image, ImageOps, UnidentifiedImageError
from aiohttp import hdrs, web from aiohttp import hdrs, web
@ -71,7 +72,7 @@ class ImageStorageCollection(collection.StorageCollection):
self.async_add_listener(self._change_listener) self.async_add_listener(self._change_listener)
self.image_dir = image_dir self.image_dir = image_dir
async def _process_create_data(self, data: dict) -> dict: async def _process_create_data(self, data: dict[str, Any]) -> dict[str, Any]:
"""Validate the config is valid.""" """Validate the config is valid."""
data = self.CREATE_SCHEMA(dict(data)) data = self.CREATE_SCHEMA(dict(data))
uploaded_file: FileField = data["file"] uploaded_file: FileField = data["file"]
@ -88,7 +89,7 @@ class ImageStorageCollection(collection.StorageCollection):
return data return data
def _move_data(self, data): def _move_data(self, data: dict[str, Any]) -> int:
"""Move data.""" """Move data."""
uploaded_file: FileField = data.pop("file") uploaded_file: FileField = data.pop("file")
@ -119,15 +120,24 @@ class ImageStorageCollection(collection.StorageCollection):
return media_file.stat().st_size return media_file.stat().st_size
@callback @callback
def _get_suggested_id(self, info: dict) -> str: def _get_suggested_id(self, info: dict[str, Any]) -> str:
"""Suggest an ID based on the config.""" """Suggest an ID based on the config."""
return info[CONF_ID] return str(info[CONF_ID])
async def _update_data(self, data: dict, update_data: dict) -> dict: async def _update_data(
self,
data: dict[str, Any],
update_data: dict[str, Any],
) -> dict[str, Any]:
"""Return a new updated data object.""" """Return a new updated data object."""
return {**data, **self.UPDATE_SCHEMA(update_data)} return {**data, **self.UPDATE_SCHEMA(update_data)}
async def _change_listener(self, change_type, item_id, data): async def _change_listener(
self,
change_type: str,
item_id: str,
data: dict[str, Any],
) -> None:
"""Handle change.""" """Handle change."""
if change_type != collection.CHANGE_REMOVED: if change_type != collection.CHANGE_REMOVED:
return return
@ -141,7 +151,7 @@ class ImageUploadView(HomeAssistantView):
url = "/api/image/upload" url = "/api/image/upload"
name = "api:image:upload" name = "api:image:upload"
async def post(self, request): async def post(self, request: web.Request) -> web.Response:
"""Handle upload.""" """Handle upload."""
# Increase max payload # Increase max payload
request._client_max_size = MAX_SIZE # pylint: disable=protected-access request._client_max_size = MAX_SIZE # pylint: disable=protected-access
@ -159,26 +169,27 @@ class ImageServeView(HomeAssistantView):
requires_auth = False requires_auth = False
def __init__( def __init__(
self, image_folder: pathlib.Path, image_collection: ImageStorageCollection self,
image_folder: pathlib.Path,
image_collection: ImageStorageCollection,
) -> None: ) -> None:
"""Initialize image serve view.""" """Initialize image serve view."""
self.transform_lock = asyncio.Lock() self.transform_lock = asyncio.Lock()
self.image_folder = image_folder self.image_folder = image_folder
self.image_collection = image_collection self.image_collection = image_collection
async def get(self, request: web.Request, image_id: str, filename: str): async def get(
self,
request: web.Request,
image_id: str,
filename: str,
) -> web.FileResponse:
"""Serve image.""" """Serve image."""
image_size = filename.split("-", 1)[0]
try: try:
parts = image_size.split("x", 1) width, height = _validate_size_from_filename(filename)
width = int(parts[0])
height = int(parts[1])
except (ValueError, IndexError) as err: except (ValueError, IndexError) as err:
raise web.HTTPBadRequest from err raise web.HTTPBadRequest from err
if not width or width != height or width not in VALID_SIZES:
raise web.HTTPBadRequest
image_info = self.image_collection.data.get(image_id) image_info = self.image_collection.data.get(image_id)
if image_info is None: if image_info is None:
@ -205,8 +216,33 @@ class ImageServeView(HomeAssistantView):
) )
def _generate_thumbnail(original_path, content_type, target_path, target_size): def _generate_thumbnail(
original_path: pathlib.Path,
content_type: str,
target_path: pathlib.Path,
target_size: tuple[int, int],
) -> None:
"""Generate a size.""" """Generate a size."""
image = ImageOps.exif_transpose(Image.open(original_path)) image = ImageOps.exif_transpose(Image.open(original_path))
image.thumbnail(target_size) image.thumbnail(target_size)
image.save(target_path, format=content_type.split("/", 1)[1]) image.save(target_path, format=content_type.partition("/")[-1])
def _validate_size_from_filename(filename: str) -> tuple[int, int]:
"""Parse image size from the given filename (of the form WIDTHxHEIGHT-filename).
>>> _validate_size_from_filename("100x100-image.png")
(100, 100)
>>> _validate_size_from_filename("jeff.png")
Traceback (most recent call last):
...
"""
image_size = filename.partition("-")[0]
if not image_size:
raise ValueError("Invalid filename")
width_s, _, height_s = image_size.partition("x")
width = int(width_s)
height = int(height_s)
if not width or width != height or width not in VALID_SIZES:
raise ValueError(f"Invalid size {image_size}")
return (width, height)

View File

@ -1263,6 +1263,16 @@ disallow_untyped_defs = true
warn_return_any = true warn_return_any = true
warn_unreachable = true warn_unreachable = true
[mypy-homeassistant.components.image.*]
check_untyped_defs = true
disallow_incomplete_defs = true
disallow_subclassing_any = true
disallow_untyped_calls = true
disallow_untyped_decorators = true
disallow_untyped_defs = true
warn_return_any = true
warn_unreachable = true
[mypy-homeassistant.components.image_processing.*] [mypy-homeassistant.components.image_processing.*]
check_untyped_defs = true check_untyped_defs = true
disallow_incomplete_defs = true disallow_incomplete_defs = true