Compare commits

..

2 Commits

Author SHA1 Message Date
Stefan Agner
6f9ad37b16 Fix type annotations in addon options validation
Add missing type annotations to AddonOptions and UiOptions classes:
- Add parameter and return type to AddonOptions.__call__
- Add explicit type annotation to UiOptions.coresys attribute
- Add return type to UiOptions._ui_schema_element method

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-03 21:35:59 +01:00
Stefan Agner
57e0b14f25 Fix type annotations in addon options validation
The type annotations for validation methods in AddonOptions and
UiOptions were overly restrictive and did not match runtime behavior:

- _nested_validate_list and _nested_validate_dict receive user input
  that could be any type, with runtime isinstance checks to validate.
  Changed parameter types from list[Any]/dict[Any, Any] to Any.

- _ui_schema_element handles str, list, and dict values depending on
  the schema structure. Changed from str to the union type.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-03 21:34:40 +01:00
12 changed files with 46 additions and 83 deletions

View File

@@ -75,7 +75,7 @@ class AddonOptions(CoreSysAttributes):
"""Create a schema for add-on options."""
return vol.Schema(vol.All(dict, self))
def __call__(self, struct):
def __call__(self, struct: dict[str, Any]) -> dict[str, Any]:
"""Create schema validator for add-ons options."""
options = {}
@@ -193,9 +193,7 @@ class AddonOptions(CoreSysAttributes):
f"Fatal error for option '{key}' with type '{typ}' in {self._name} ({self._slug})"
) from None
def _nested_validate_list(
self, typ: Any, data_list: list[Any], key: str
) -> list[Any]:
def _nested_validate_list(self, typ: Any, data_list: Any, key: str) -> list[Any]:
"""Validate nested items."""
options = []
@@ -213,7 +211,7 @@ class AddonOptions(CoreSysAttributes):
return options
def _nested_validate_dict(
self, typ: dict[Any, Any], data_dict: dict[Any, Any], key: str
self, typ: dict[Any, Any], data_dict: Any, key: str
) -> dict[Any, Any]:
"""Validate nested items."""
options = {}
@@ -264,7 +262,7 @@ class UiOptions(CoreSysAttributes):
def __init__(self, coresys: CoreSys) -> None:
"""Initialize UI option render."""
self.coresys = coresys
self.coresys: CoreSys = coresys
def __call__(self, raw_schema: dict[str, Any]) -> list[dict[str, Any]]:
"""Generate UI schema."""
@@ -279,10 +277,10 @@ class UiOptions(CoreSysAttributes):
def _ui_schema_element(
self,
ui_schema: list[dict[str, Any]],
value: str,
value: str | list[Any] | dict[str, Any],
key: str,
multiple: bool = False,
):
) -> None:
if isinstance(value, list):
# nested value list
assert not multiple

View File

@@ -390,7 +390,7 @@ class APIAddons(CoreSysAttributes):
return data
@api_process
async def options_config(self, request: web.Request) -> dict[str, Any]:
async def options_config(self, request: web.Request) -> None:
"""Validate user options for add-on."""
slug: str = request.match_info["addon"]
if slug != "self":
@@ -435,11 +435,11 @@ class APIAddons(CoreSysAttributes):
}
@api_process
async def uninstall(self, request: web.Request) -> None:
async def uninstall(self, request: web.Request) -> Awaitable[None]:
"""Uninstall add-on."""
addon = self.get_addon_for_request(request)
body: dict[str, Any] = await api_validate(SCHEMA_UNINSTALL, request)
await asyncio.shield(
return await asyncio.shield(
self.sys_addons.uninstall(
addon.slug, remove_config=body[ATTR_REMOVE_CONFIG]
)

View File

@@ -211,7 +211,7 @@ class APIBackups(CoreSysAttributes):
await self.sys_backups.save_data()
@api_process
async def reload(self, _: web.Request) -> bool:
async def reload(self, _):
"""Reload backup list."""
await asyncio.shield(self.sys_backups.reload())
return True
@@ -421,7 +421,7 @@ class APIBackups(CoreSysAttributes):
await self.sys_backups.remove(backup, locations=locations)
@api_process
async def download(self, request: web.Request) -> web.StreamResponse:
async def download(self, request: web.Request):
"""Download a backup file."""
backup = self._extract_slug(request)
# Query will give us '' for /backups, convert value to None
@@ -451,7 +451,7 @@ class APIBackups(CoreSysAttributes):
return response
@api_process
async def upload(self, request: web.Request) -> dict[str, str] | bool:
async def upload(self, request: web.Request):
"""Upload a backup file."""
location: LOCATION_TYPE = None
locations: list[LOCATION_TYPE] | None = None

View File

@@ -55,7 +55,7 @@ class APIDocker(CoreSysAttributes):
"""Handle RESTful API for Docker configuration."""
@api_process
async def info(self, request: web.Request) -> dict[str, Any]:
async def info(self, request: web.Request):
"""Get docker info."""
data_registries = {}
for hostname, registry in self.sys_docker.config.registries.items():
@@ -113,7 +113,7 @@ class APIDocker(CoreSysAttributes):
return {ATTR_REGISTRIES: data_registries}
@api_process
async def create_registry(self, request: web.Request) -> None:
async def create_registry(self, request: web.Request):
"""Create a new docker registry."""
body = await api_validate(SCHEMA_DOCKER_REGISTRY, request)
@@ -123,7 +123,7 @@ class APIDocker(CoreSysAttributes):
await self.sys_docker.config.save_data()
@api_process
async def remove_registry(self, request: web.Request) -> None:
async def remove_registry(self, request: web.Request):
"""Delete a docker registry."""
hostname = request.match_info.get(ATTR_HOSTNAME)
if hostname not in self.sys_docker.config.registries:

View File

@@ -154,7 +154,7 @@ class APIHomeAssistant(CoreSysAttributes):
await self.sys_homeassistant.save_data()
@api_process
async def stats(self, request: web.Request) -> dict[str, Any]:
async def stats(self, request: web.Request) -> dict[Any, str]:
"""Return resource information."""
stats = await self.sys_homeassistant.core.stats()
if not stats:
@@ -191,7 +191,7 @@ class APIHomeAssistant(CoreSysAttributes):
return await update_task
@api_process
async def stop(self, request: web.Request) -> None:
async def stop(self, request: web.Request) -> Awaitable[None]:
"""Stop Home Assistant."""
body = await api_validate(SCHEMA_STOP, request)
await self._check_offline_migration(force=body[ATTR_FORCE])

View File

@@ -1,7 +1,6 @@
"""Init file for Supervisor host RESTful API."""
import asyncio
from collections.abc import Awaitable
from contextlib import suppress
import json
import logging
@@ -100,7 +99,7 @@ class APIHost(CoreSysAttributes):
)
@api_process
async def info(self, request: web.Request) -> dict[str, Any]:
async def info(self, request):
"""Return host information."""
return {
ATTR_AGENT_VERSION: self.sys_dbus.agent.version,
@@ -129,7 +128,7 @@ class APIHost(CoreSysAttributes):
}
@api_process
async def options(self, request: web.Request) -> None:
async def options(self, request):
"""Edit host settings."""
body = await api_validate(SCHEMA_OPTIONS, request)
@@ -140,7 +139,7 @@ class APIHost(CoreSysAttributes):
)
@api_process
async def reboot(self, request: web.Request) -> None:
async def reboot(self, request):
"""Reboot host."""
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
@@ -148,7 +147,7 @@ class APIHost(CoreSysAttributes):
return await asyncio.shield(self.sys_host.control.reboot())
@api_process
async def shutdown(self, request: web.Request) -> None:
async def shutdown(self, request):
"""Poweroff host."""
body = await api_validate(SCHEMA_SHUTDOWN, request)
await self._check_ha_offline_migration(force=body[ATTR_FORCE])
@@ -156,12 +155,12 @@ class APIHost(CoreSysAttributes):
return await asyncio.shield(self.sys_host.control.shutdown())
@api_process
def reload(self, request: web.Request) -> Awaitable[None]:
def reload(self, request):
"""Reload host data."""
return asyncio.shield(self.sys_host.reload())
@api_process
async def services(self, request: web.Request) -> dict[str, Any]:
async def services(self, request):
"""Return list of available services."""
services = []
for unit in self.sys_host.services:
@@ -176,7 +175,7 @@ class APIHost(CoreSysAttributes):
return {ATTR_SERVICES: services}
@api_process
async def list_boots(self, _: web.Request) -> dict[str, Any]:
async def list_boots(self, _: web.Request):
"""Return a list of boot IDs."""
boot_ids = await self.sys_host.logs.get_boot_ids()
return {
@@ -187,7 +186,7 @@ class APIHost(CoreSysAttributes):
}
@api_process
async def list_identifiers(self, _: web.Request) -> dict[str, list[str]]:
async def list_identifiers(self, _: web.Request):
"""Return a list of syslog identifiers."""
return {ATTR_IDENTIFIERS: await self.sys_host.logs.get_identifiers()}
@@ -333,7 +332,7 @@ class APIHost(CoreSysAttributes):
)
@api_process
async def disk_usage(self, request: web.Request) -> dict[str, Any]:
async def disk_usage(self, request: web.Request) -> dict:
"""Return a breakdown of storage usage for the system."""
max_depth = request.query.get(ATTR_MAX_DEPTH, 1)

View File

@@ -1,12 +1,12 @@
"""Handle security part of this API."""
from collections.abc import Awaitable, Callable
from collections.abc import Callable
import logging
import re
from typing import Final
from urllib.parse import unquote
from aiohttp.web import Request, StreamResponse, middleware
from aiohttp.web import Request, Response, middleware
from aiohttp.web_exceptions import HTTPBadRequest, HTTPForbidden, HTTPUnauthorized
from awesomeversion import AwesomeVersion
@@ -89,7 +89,7 @@ CORE_ONLY_PATHS: Final = re.compile(
)
# Policy role add-on API access
ADDONS_ROLE_ACCESS: dict[str, re.Pattern[str]] = {
ADDONS_ROLE_ACCESS: dict[str, re.Pattern] = {
ROLE_DEFAULT: re.compile(
r"^(?:"
r"|/.+/info"
@@ -180,9 +180,7 @@ class SecurityMiddleware(CoreSysAttributes):
return unquoted
@middleware
async def block_bad_requests(
self, request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
async def block_bad_requests(self, request: Request, handler: Callable) -> Response:
"""Process request and tblock commonly known exploit attempts."""
if FILTERS.search(self._recursive_unquote(request.path)):
_LOGGER.warning(
@@ -200,9 +198,7 @@ class SecurityMiddleware(CoreSysAttributes):
return await handler(request)
@middleware
async def system_validation(
self, request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
async def system_validation(self, request: Request, handler: Callable) -> Response:
"""Check if core is ready to response."""
if self.sys_core.state not in VALID_API_STATES:
return api_return_error(
@@ -212,9 +208,7 @@ class SecurityMiddleware(CoreSysAttributes):
return await handler(request)
@middleware
async def token_validation(
self, request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
async def token_validation(self, request: Request, handler: Callable) -> Response:
"""Check security access of this layer."""
request_from: CoreSysAttributes | None = None
supervisor_token = extract_supervisor_token(request)
@@ -285,9 +279,7 @@ class SecurityMiddleware(CoreSysAttributes):
raise HTTPForbidden()
@middleware
async def core_proxy(
self, request: Request, handler: Callable[[Request], Awaitable[StreamResponse]]
) -> StreamResponse:
async def core_proxy(self, request: Request, handler: Callable) -> Response:
"""Validate user from Core API proxy."""
if (
request[REQUEST_FROM] != self.sys_homeassistant

View File

@@ -1,9 +1,5 @@
"""Init file for Supervisor network RESTful API."""
from typing import Any
from aiohttp import web
from ..const import (
ATTR_AVAILABLE,
ATTR_PROVIDERS,
@@ -29,7 +25,7 @@ class APIServices(CoreSysAttributes):
return service
@api_process
async def list_services(self, request: web.Request) -> dict[str, Any]:
async def list_services(self, request):
"""Show register services."""
services = []
for service in self.sys_services.list_services:
@@ -44,7 +40,7 @@ class APIServices(CoreSysAttributes):
return {ATTR_SERVICES: services}
@api_process
async def set_service(self, request: web.Request) -> None:
async def set_service(self, request):
"""Write data into a service."""
service = self._extract_service(request)
body = await api_validate(service.schema, request)
@@ -54,7 +50,7 @@ class APIServices(CoreSysAttributes):
await service.set_service_data(addon, body)
@api_process
async def get_service(self, request: web.Request) -> dict[str, Any]:
async def get_service(self, request):
"""Read data into a service."""
service = self._extract_service(request)
@@ -66,7 +62,7 @@ class APIServices(CoreSysAttributes):
return service.get_service_data()
@api_process
async def del_service(self, request: web.Request) -> None:
async def del_service(self, request):
"""Delete data into a service."""
service = self._extract_service(request)
addon = request[REQUEST_FROM]

View File

@@ -349,13 +349,13 @@ class APIStore(CoreSysAttributes):
return self._generate_repository_information(repository)
@api_process
async def add_repository(self, request: web.Request) -> None:
async def add_repository(self, request: web.Request):
"""Add repository to the store."""
body = await api_validate(SCHEMA_ADD_REPOSITORY, request)
await asyncio.shield(self.sys_store.add_repository(body[ATTR_REPOSITORY]))
@api_process
async def remove_repository(self, request: web.Request) -> None:
async def remove_repository(self, request: web.Request):
"""Remove repository from the store."""
repository: Repository = self._extract_repository(request)
await asyncio.shield(self.sys_store.remove_repository(repository))

View File

@@ -80,7 +80,7 @@ class APISupervisor(CoreSysAttributes):
"""Handle RESTful API for Supervisor functions."""
@api_process
async def ping(self, request: web.Request) -> bool:
async def ping(self, request):
"""Return ok for signal that the API is ready."""
return True

View File

@@ -474,10 +474,8 @@ class DockerAPI(CoreSysAttributes):
raises only if the get fails afterwards. Additionally it fires progress reports for the pull
on the bus so listeners can use that to update status for users.
"""
# Use timeout=None to disable timeout for pull operations, matching docker-py behavior.
# aiodocker converts None to ClientTimeout(total=None) which disables the timeout.
async for e in self.images.pull(
repository, tag=tag, platform=platform, auth=auth, stream=True, timeout=None
repository, tag=tag, platform=platform, auth=auth, stream=True
):
entry = PullLogEntry.from_pull_log_dict(job_id, e)
if entry.error:

View File

@@ -54,7 +54,7 @@ async def test_docker_image_platform(
coresys.docker.images.inspect.return_value = {"Id": "test:1.2.3"}
await test_docker_interface.install(AwesomeVersion("1.2.3"), "test", arch=cpu_arch)
coresys.docker.images.pull.assert_called_once_with(
"test", tag="1.2.3", platform=platform, auth=None, stream=True, timeout=None
"test", tag="1.2.3", platform=platform, auth=None, stream=True
)
coresys.docker.images.inspect.assert_called_once_with("test:1.2.3")
@@ -71,12 +71,7 @@ async def test_docker_image_default_platform(
):
await test_docker_interface.install(AwesomeVersion("1.2.3"), "test")
coresys.docker.images.pull.assert_called_once_with(
"test",
tag="1.2.3",
platform="linux/386",
auth=None,
stream=True,
timeout=None,
"test", tag="1.2.3", platform="linux/386", auth=None, stream=True
)
coresys.docker.images.inspect.assert_called_once_with("test:1.2.3")
@@ -116,12 +111,7 @@ async def test_private_registry_credentials_passed_to_pull(
expected_auth["registry"] = registry_key
coresys.docker.images.pull.assert_called_once_with(
image,
tag="1.2.3",
platform="linux/amd64",
auth=expected_auth,
stream=True,
timeout=None,
image, tag="1.2.3", platform="linux/amd64", auth=expected_auth, stream=True
)
@@ -370,12 +360,7 @@ async def test_install_fires_progress_events(
):
await test_docker_interface.install(AwesomeVersion("1.2.3"), "test")
coresys.docker.images.pull.assert_called_once_with(
"test",
tag="1.2.3",
platform="linux/386",
auth=None,
stream=True,
timeout=None,
"test", tag="1.2.3", platform="linux/386", auth=None, stream=True
)
coresys.docker.images.inspect.assert_called_once_with("test:1.2.3")
@@ -832,12 +817,7 @@ async def test_install_progress_containerd_snapshot(
with patch.object(Supervisor, "arch", PropertyMock(return_value="i386")):
await test_docker_interface.mock_install()
coresys.docker.images.pull.assert_called_once_with(
"test",
tag="1.2.3",
platform="linux/386",
auth=None,
stream=True,
timeout=None,
"test", tag="1.2.3", platform="linux/386", auth=None, stream=True
)
coresys.docker.images.inspect.assert_called_once_with("test:1.2.3")