Fix circular imports in core integrations (#111875)

* Fix circular imports in core integrations

* fix circular import

* fix more circular imports

* fix more circular imports

* fix more circular imports

* fix more circular imports

* fix more circular imports

* fix more circular imports

* fix more circular imports

* adjust

* fix

* increase timeout

* remove unused logger

* keep up to date

* make sure its reprod
This commit is contained in:
J. Nick Koston 2024-02-29 16:04:41 -10:00 committed by GitHub
parent 25510fc13c
commit c1750f7c3a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
9 changed files with 242 additions and 192 deletions

View File

@ -32,6 +32,11 @@ from homeassistant.core import Event, HomeAssistant
from homeassistant.exceptions import HomeAssistantError
from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.http import (
KEY_AUTHENTICATED, # noqa: F401
HomeAssistantView,
current_request,
)
from homeassistant.helpers.network import NoURLAvailableError, get_url
from homeassistant.helpers.typing import ConfigType
from homeassistant.loader import bind_hass
@ -41,20 +46,14 @@ from homeassistant.util.json import json_loads
from .auth import async_setup_auth
from .ban import setup_bans
from .const import ( # noqa: F401
KEY_AUTHENTICATED,
KEY_HASS,
KEY_HASS_REFRESH_TOKEN_ID,
KEY_HASS_USER,
)
from .const import KEY_HASS, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
from .cors import setup_cors
from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded
from .headers import setup_headers
from .request_context import current_request, setup_request_context
from .request_context import setup_request_context
from .security_filter import setup_security_filter
from .static import CACHE_HEADERS, CachingStaticResource
from .view import HomeAssistantView
from .web_runner import HomeAssistantTCPSite
DOMAIN: Final = "http"

View File

@ -20,13 +20,13 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
from homeassistant.auth.models import User
from homeassistant.components import websocket_api
from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.http import current_request
from homeassistant.helpers.json import json_bytes
from homeassistant.helpers.network import is_cloud_connection
from homeassistant.helpers.storage import Store
from homeassistant.util.network import is_local
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
from .request_context import current_request
_LOGGER = logging.getLogger(__name__)

View File

@ -15,7 +15,6 @@ from aiohttp.web import Application, Request, Response, StreamResponse, middlewa
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
import voluptuous as vol
from homeassistant.components import persistent_notification
from homeassistant.config import load_yaml_config_file
from homeassistant.core import HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError
@ -128,6 +127,10 @@ async def process_wrong_login(request: Request) -> None:
_LOGGER.warning(log_msg)
# Circular import with websocket_api
# pylint: disable=import-outside-toplevel
from homeassistant.components import persistent_notification
persistent_notification.async_create(
hass, notification_msg, "Login attempt failed", NOTIFICATION_ID_LOGIN
)

View File

@ -1,7 +1,8 @@
"""HTTP specific constants."""
from typing import Final
KEY_AUTHENTICATED: Final = "ha_authenticated"
from homeassistant.helpers.http import KEY_AUTHENTICATED # noqa: F401
KEY_HASS: Final = "hass"
KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"

View File

@ -7,10 +7,7 @@ from contextvars import ContextVar
from aiohttp.web import Application, Request, StreamResponse, middleware
from homeassistant.core import callback
current_request: ContextVar[Request | None] = ContextVar(
"current_request", default=None
)
from homeassistant.helpers.http import current_request # noqa: F401
@callback

View File

@ -1,180 +1,7 @@
"""Support for views."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from http import HTTPStatus
import logging
from typing import Any
from aiohttp import web
from aiohttp.typedefs import LooseHeaders
from aiohttp.web_exceptions import (
HTTPBadRequest,
HTTPInternalServerError,
HTTPUnauthorized,
from homeassistant.helpers.http import ( # noqa: F401
HomeAssistantView,
request_handler_factory,
)
from aiohttp.web_urldispatcher import AbstractRoute
import voluptuous as vol
from homeassistant import exceptions
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant.core import Context, HomeAssistant, is_callback
from homeassistant.helpers.json import (
find_paths_unserializable_data,
json_bytes,
json_dumps,
)
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
from .const import KEY_AUTHENTICATED
_LOGGER = logging.getLogger(__name__)
class HomeAssistantView:
"""Base view for all views."""
url: str | None = None
extra_urls: list[str] = []
# Views inheriting from this class can override this
requires_auth = True
cors_allowed = False
@staticmethod
def context(request: web.Request) -> Context:
"""Generate a context from a request."""
if (user := request.get("hass_user")) is None:
return Context()
return Context(user_id=user.id)
@staticmethod
def json(
result: Any,
status_code: HTTPStatus | int = HTTPStatus.OK,
headers: LooseHeaders | None = None,
) -> web.Response:
"""Return a JSON response."""
try:
msg = json_bytes(result)
except JSON_ENCODE_EXCEPTIONS as err:
_LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(result, dump=json_dumps)
),
)
raise HTTPInternalServerError from err
response = web.Response(
body=msg,
content_type=CONTENT_TYPE_JSON,
status=int(status_code),
headers=headers,
zlib_executor_size=32768,
)
response.enable_compression()
return response
def json_message(
self,
message: str,
status_code: HTTPStatus | int = HTTPStatus.OK,
message_code: str | None = None,
headers: LooseHeaders | None = None,
) -> web.Response:
"""Return a JSON message response."""
data = {"message": message}
if message_code is not None:
data["code"] = message_code
return self.json(data, status_code, headers=headers)
def register(
self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
) -> None:
"""Register the view with a router."""
assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls
routes: list[AbstractRoute] = []
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
if not (handler := getattr(self, method, None)):
continue
handler = request_handler_factory(hass, self, handler)
for url in urls:
routes.append(router.add_route(method, url, handler))
# Use `get` because CORS middleware is not be loaded in emulated_hue
if self.cors_allowed:
allow_cors = app.get("allow_all_cors")
else:
allow_cors = app.get("allow_configured_cors")
if allow_cors:
for route in routes:
allow_cors(route)
def request_handler_factory(
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes."""
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
assert is_coroutinefunction or is_callback(
handler
), "Handler should be a coroutine or a callback."
async def handle(request: web.Request) -> web.StreamResponse:
"""Handle incoming request."""
if hass.is_stopping:
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Serving %s to %s (auth: %s)",
request.path,
request.remote,
authenticated,
)
try:
if is_coroutinefunction:
result = await handler(request, **request.match_info)
else:
result = handler(request, **request.match_info)
except vol.Invalid as err:
raise HTTPBadRequest() from err
except exceptions.ServiceNotFound as err:
raise HTTPInternalServerError() from err
except exceptions.Unauthorized as err:
raise HTTPUnauthorized() from err
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = HTTPStatus.OK
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, bytes):
return web.Response(body=result, status=status_code)
if isinstance(result, str):
return web.Response(text=result, status=status_code)
if result is None:
return web.Response(body=b"", status=status_code)
raise TypeError(
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
)
return handle

View File

@ -9,9 +9,9 @@ from aiohttp import web
import voluptuous as vol
from homeassistant.auth.models import RefreshToken, User
from homeassistant.components.http import current_request
from homeassistant.core import Context, HomeAssistant, callback
from homeassistant.exceptions import HomeAssistantError, Unauthorized
from homeassistant.helpers.http import current_request
from homeassistant.util.json import JsonValueType
from . import const, messages

View File

@ -0,0 +1,184 @@
"""Helper to track the current http request."""
from __future__ import annotations
import asyncio
from collections.abc import Awaitable, Callable
from contextvars import ContextVar
from http import HTTPStatus
import logging
from typing import Any, Final
from aiohttp import web
from aiohttp.typedefs import LooseHeaders
from aiohttp.web import Request
from aiohttp.web_exceptions import (
HTTPBadRequest,
HTTPInternalServerError,
HTTPUnauthorized,
)
from aiohttp.web_urldispatcher import AbstractRoute
import voluptuous as vol
from homeassistant import exceptions
from homeassistant.const import CONTENT_TYPE_JSON
from homeassistant.core import Context, HomeAssistant, is_callback
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
from .json import find_paths_unserializable_data, json_bytes, json_dumps
_LOGGER = logging.getLogger(__name__)
KEY_AUTHENTICATED: Final = "ha_authenticated"
current_request: ContextVar[Request | None] = ContextVar(
"current_request", default=None
)
def request_handler_factory(
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
"""Wrap the handler classes."""
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
assert is_coroutinefunction or is_callback(
handler
), "Handler should be a coroutine or a callback."
async def handle(request: web.Request) -> web.StreamResponse:
"""Handle incoming request."""
if hass.is_stopping:
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
authenticated = request.get(KEY_AUTHENTICATED, False)
if view.requires_auth and not authenticated:
raise HTTPUnauthorized()
if _LOGGER.isEnabledFor(logging.DEBUG):
_LOGGER.debug(
"Serving %s to %s (auth: %s)",
request.path,
request.remote,
authenticated,
)
try:
if is_coroutinefunction:
result = await handler(request, **request.match_info)
else:
result = handler(request, **request.match_info)
except vol.Invalid as err:
raise HTTPBadRequest() from err
except exceptions.ServiceNotFound as err:
raise HTTPInternalServerError() from err
except exceptions.Unauthorized as err:
raise HTTPUnauthorized() from err
if isinstance(result, web.StreamResponse):
# The method handler returned a ready-made Response, how nice of it
return result
status_code = HTTPStatus.OK
if isinstance(result, tuple):
result, status_code = result
if isinstance(result, bytes):
return web.Response(body=result, status=status_code)
if isinstance(result, str):
return web.Response(text=result, status=status_code)
if result is None:
return web.Response(body=b"", status=status_code)
raise TypeError(
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
)
return handle
class HomeAssistantView:
"""Base view for all views."""
url: str | None = None
extra_urls: list[str] = []
# Views inheriting from this class can override this
requires_auth = True
cors_allowed = False
@staticmethod
def context(request: web.Request) -> Context:
"""Generate a context from a request."""
if (user := request.get("hass_user")) is None:
return Context()
return Context(user_id=user.id)
@staticmethod
def json(
result: Any,
status_code: HTTPStatus | int = HTTPStatus.OK,
headers: LooseHeaders | None = None,
) -> web.Response:
"""Return a JSON response."""
try:
msg = json_bytes(result)
except JSON_ENCODE_EXCEPTIONS as err:
_LOGGER.error(
"Unable to serialize to JSON. Bad data found at %s",
format_unserializable_data(
find_paths_unserializable_data(result, dump=json_dumps)
),
)
raise HTTPInternalServerError from err
response = web.Response(
body=msg,
content_type=CONTENT_TYPE_JSON,
status=int(status_code),
headers=headers,
zlib_executor_size=32768,
)
response.enable_compression()
return response
def json_message(
self,
message: str,
status_code: HTTPStatus | int = HTTPStatus.OK,
message_code: str | None = None,
headers: LooseHeaders | None = None,
) -> web.Response:
"""Return a JSON message response."""
data = {"message": message}
if message_code is not None:
data["code"] = message_code
return self.json(data, status_code, headers=headers)
def register(
self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
) -> None:
"""Register the view with a router."""
assert self.url is not None, "No url set for view"
urls = [self.url] + self.extra_urls
routes: list[AbstractRoute] = []
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
if not (handler := getattr(self, method, None)):
continue
handler = request_handler_factory(hass, self, handler)
for url in urls:
routes.append(router.add_route(method, url, handler))
# Use `get` because CORS middleware is not be loaded in emulated_hue
if self.cors_allowed:
allow_cors = app.get("allow_all_cors")
else:
allow_cors = app.get("allow_configured_cors")
if allow_cors:
for route in routes:
allow_cors(route)

View File

@ -0,0 +1,39 @@
"""Test to check for circular imports in core components."""
import asyncio
import sys
import pytest
from homeassistant.bootstrap import (
CORE_INTEGRATIONS,
DEBUGGER_INTEGRATIONS,
DEFAULT_INTEGRATIONS,
FRONTEND_INTEGRATIONS,
LOGGING_INTEGRATIONS,
RECORDER_INTEGRATIONS,
STAGE_1_INTEGRATIONS,
)
@pytest.mark.timeout(30) # cloud can take > 9s
@pytest.mark.parametrize(
"component",
sorted(
{
*DEBUGGER_INTEGRATIONS,
*CORE_INTEGRATIONS,
*LOGGING_INTEGRATIONS,
*FRONTEND_INTEGRATIONS,
*RECORDER_INTEGRATIONS,
*STAGE_1_INTEGRATIONS,
*DEFAULT_INTEGRATIONS,
}
),
)
async def test_circular_imports(component: str) -> None:
"""Check that components can be imported without circular imports."""
process = await asyncio.create_subprocess_exec(
sys.executable, "-c", f"import homeassistant.components.{component}"
)
await process.communicate()
assert process.returncode == 0