mirror of
https://github.com/home-assistant/core.git
synced 2025-04-23 16:57:53 +00:00
Fix circular imports in core integrations (#111875)
* Fix circular imports in core integrations * fix circular import * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * fix more circular imports * adjust * fix * increase timeout * remove unused logger * keep up to date * make sure its reprod
This commit is contained in:
parent
25510fc13c
commit
c1750f7c3a
@ -32,6 +32,11 @@ from homeassistant.core import Event, HomeAssistant
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.helpers import storage
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.http import (
|
||||
KEY_AUTHENTICATED, # noqa: F401
|
||||
HomeAssistantView,
|
||||
current_request,
|
||||
)
|
||||
from homeassistant.helpers.network import NoURLAvailableError, get_url
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
from homeassistant.loader import bind_hass
|
||||
@ -41,20 +46,14 @@ from homeassistant.util.json import json_loads
|
||||
|
||||
from .auth import async_setup_auth
|
||||
from .ban import setup_bans
|
||||
from .const import ( # noqa: F401
|
||||
KEY_AUTHENTICATED,
|
||||
KEY_HASS,
|
||||
KEY_HASS_REFRESH_TOKEN_ID,
|
||||
KEY_HASS_USER,
|
||||
)
|
||||
from .const import KEY_HASS, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
|
||||
from .cors import setup_cors
|
||||
from .decorators import require_admin # noqa: F401
|
||||
from .forwarded import async_setup_forwarded
|
||||
from .headers import setup_headers
|
||||
from .request_context import current_request, setup_request_context
|
||||
from .request_context import setup_request_context
|
||||
from .security_filter import setup_security_filter
|
||||
from .static import CACHE_HEADERS, CachingStaticResource
|
||||
from .view import HomeAssistantView
|
||||
from .web_runner import HomeAssistantTCPSite
|
||||
|
||||
DOMAIN: Final = "http"
|
||||
|
@ -20,13 +20,13 @@ from homeassistant.auth.const import GROUP_ID_READ_ONLY
|
||||
from homeassistant.auth.models import User
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.helpers.json import json_bytes
|
||||
from homeassistant.helpers.network import is_cloud_connection
|
||||
from homeassistant.helpers.storage import Store
|
||||
from homeassistant.util.network import is_local
|
||||
|
||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||
from .request_context import current_request
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
@ -15,7 +15,6 @@ from aiohttp.web import Application, Request, Response, StreamResponse, middlewa
|
||||
from aiohttp.web_exceptions import HTTPForbidden, HTTPUnauthorized
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components import persistent_notification
|
||||
from homeassistant.config import load_yaml_config_file
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
@ -128,6 +127,10 @@ async def process_wrong_login(request: Request) -> None:
|
||||
|
||||
_LOGGER.warning(log_msg)
|
||||
|
||||
# Circular import with websocket_api
|
||||
# pylint: disable=import-outside-toplevel
|
||||
from homeassistant.components import persistent_notification
|
||||
|
||||
persistent_notification.async_create(
|
||||
hass, notification_msg, "Login attempt failed", NOTIFICATION_ID_LOGIN
|
||||
)
|
||||
|
@ -1,7 +1,8 @@
|
||||
"""HTTP specific constants."""
|
||||
from typing import Final
|
||||
|
||||
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||
from homeassistant.helpers.http import KEY_AUTHENTICATED # noqa: F401
|
||||
|
||||
KEY_HASS: Final = "hass"
|
||||
KEY_HASS_USER: Final = "hass_user"
|
||||
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"
|
||||
|
@ -7,10 +7,7 @@ from contextvars import ContextVar
|
||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
current_request: ContextVar[Request | None] = ContextVar(
|
||||
"current_request", default=None
|
||||
)
|
||||
from homeassistant.helpers.http import current_request # noqa: F401
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -1,180 +1,7 @@
|
||||
"""Support for views."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPBadRequest,
|
||||
HTTPInternalServerError,
|
||||
HTTPUnauthorized,
|
||||
from homeassistant.helpers.http import ( # noqa: F401
|
||||
HomeAssistantView,
|
||||
request_handler_factory,
|
||||
)
|
||||
from aiohttp.web_urldispatcher import AbstractRoute
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.const import CONTENT_TYPE_JSON
|
||||
from homeassistant.core import Context, HomeAssistant, is_callback
|
||||
from homeassistant.helpers.json import (
|
||||
find_paths_unserializable_data,
|
||||
json_bytes,
|
||||
json_dumps,
|
||||
)
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
|
||||
|
||||
from .const import KEY_AUTHENTICATED
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HomeAssistantView:
|
||||
"""Base view for all views."""
|
||||
|
||||
url: str | None = None
|
||||
extra_urls: list[str] = []
|
||||
# Views inheriting from this class can override this
|
||||
requires_auth = True
|
||||
cors_allowed = False
|
||||
|
||||
@staticmethod
|
||||
def context(request: web.Request) -> Context:
|
||||
"""Generate a context from a request."""
|
||||
if (user := request.get("hass_user")) is None:
|
||||
return Context()
|
||||
|
||||
return Context(user_id=user.id)
|
||||
|
||||
@staticmethod
|
||||
def json(
|
||||
result: Any,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON response."""
|
||||
try:
|
||||
msg = json_bytes(result)
|
||||
except JSON_ENCODE_EXCEPTIONS as err:
|
||||
_LOGGER.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(result, dump=json_dumps)
|
||||
),
|
||||
)
|
||||
raise HTTPInternalServerError from err
|
||||
response = web.Response(
|
||||
body=msg,
|
||||
content_type=CONTENT_TYPE_JSON,
|
||||
status=int(status_code),
|
||||
headers=headers,
|
||||
zlib_executor_size=32768,
|
||||
)
|
||||
response.enable_compression()
|
||||
return response
|
||||
|
||||
def json_message(
|
||||
self,
|
||||
message: str,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
message_code: str | None = None,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON message response."""
|
||||
data = {"message": message}
|
||||
if message_code is not None:
|
||||
data["code"] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
def register(
|
||||
self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
|
||||
) -> None:
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, "No url set for view"
|
||||
urls = [self.url] + self.extra_urls
|
||||
routes: list[AbstractRoute] = []
|
||||
|
||||
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
||||
if not (handler := getattr(self, method, None)):
|
||||
continue
|
||||
|
||||
handler = request_handler_factory(hass, self, handler)
|
||||
|
||||
for url in urls:
|
||||
routes.append(router.add_route(method, url, handler))
|
||||
|
||||
# Use `get` because CORS middleware is not be loaded in emulated_hue
|
||||
if self.cors_allowed:
|
||||
allow_cors = app.get("allow_all_cors")
|
||||
else:
|
||||
allow_cors = app.get("allow_configured_cors")
|
||||
|
||||
if allow_cors:
|
||||
for route in routes:
|
||||
allow_cors(route)
|
||||
|
||||
|
||||
def request_handler_factory(
|
||||
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
|
||||
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
||||
"""Wrap the handler classes."""
|
||||
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
|
||||
assert is_coroutinefunction or is_callback(
|
||||
handler
|
||||
), "Handler should be a coroutine or a callback."
|
||||
|
||||
async def handle(request: web.Request) -> web.StreamResponse:
|
||||
"""Handle incoming request."""
|
||||
if hass.is_stopping:
|
||||
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Serving %s to %s (auth: %s)",
|
||||
request.path,
|
||||
request.remote,
|
||||
authenticated,
|
||||
)
|
||||
|
||||
try:
|
||||
if is_coroutinefunction:
|
||||
result = await handler(request, **request.match_info)
|
||||
else:
|
||||
result = handler(request, **request.match_info)
|
||||
except vol.Invalid as err:
|
||||
raise HTTPBadRequest() from err
|
||||
except exceptions.ServiceNotFound as err:
|
||||
raise HTTPInternalServerError() from err
|
||||
except exceptions.Unauthorized as err:
|
||||
raise HTTPUnauthorized() from err
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
return result
|
||||
|
||||
status_code = HTTPStatus.OK
|
||||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, bytes):
|
||||
return web.Response(body=result, status=status_code)
|
||||
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result, status=status_code)
|
||||
|
||||
if result is None:
|
||||
return web.Response(body=b"", status=status_code)
|
||||
|
||||
raise TypeError(
|
||||
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
|
||||
)
|
||||
|
||||
return handle
|
||||
|
@ -9,9 +9,9 @@ from aiohttp import web
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.auth.models import RefreshToken, User
|
||||
from homeassistant.components.http import current_request
|
||||
from homeassistant.core import Context, HomeAssistant, callback
|
||||
from homeassistant.exceptions import HomeAssistantError, Unauthorized
|
||||
from homeassistant.helpers.http import current_request
|
||||
from homeassistant.util.json import JsonValueType
|
||||
|
||||
from . import const, messages
|
||||
|
184
homeassistant/helpers/http.py
Normal file
184
homeassistant/helpers/http.py
Normal file
@ -0,0 +1,184 @@
|
||||
"""Helper to track the current http request."""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
from contextvars import ContextVar
|
||||
from http import HTTPStatus
|
||||
import logging
|
||||
from typing import Any, Final
|
||||
|
||||
from aiohttp import web
|
||||
from aiohttp.typedefs import LooseHeaders
|
||||
from aiohttp.web import Request
|
||||
from aiohttp.web_exceptions import (
|
||||
HTTPBadRequest,
|
||||
HTTPInternalServerError,
|
||||
HTTPUnauthorized,
|
||||
)
|
||||
from aiohttp.web_urldispatcher import AbstractRoute
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant import exceptions
|
||||
from homeassistant.const import CONTENT_TYPE_JSON
|
||||
from homeassistant.core import Context, HomeAssistant, is_callback
|
||||
from homeassistant.util.json import JSON_ENCODE_EXCEPTIONS, format_unserializable_data
|
||||
|
||||
from .json import find_paths_unserializable_data, json_bytes, json_dumps
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
KEY_AUTHENTICATED: Final = "ha_authenticated"
|
||||
|
||||
current_request: ContextVar[Request | None] = ContextVar(
|
||||
"current_request", default=None
|
||||
)
|
||||
|
||||
|
||||
def request_handler_factory(
|
||||
hass: HomeAssistant, view: HomeAssistantView, handler: Callable
|
||||
) -> Callable[[web.Request], Awaitable[web.StreamResponse]]:
|
||||
"""Wrap the handler classes."""
|
||||
is_coroutinefunction = asyncio.iscoroutinefunction(handler)
|
||||
assert is_coroutinefunction or is_callback(
|
||||
handler
|
||||
), "Handler should be a coroutine or a callback."
|
||||
|
||||
async def handle(request: web.Request) -> web.StreamResponse:
|
||||
"""Handle incoming request."""
|
||||
if hass.is_stopping:
|
||||
return web.Response(status=HTTPStatus.SERVICE_UNAVAILABLE)
|
||||
|
||||
authenticated = request.get(KEY_AUTHENTICATED, False)
|
||||
|
||||
if view.requires_auth and not authenticated:
|
||||
raise HTTPUnauthorized()
|
||||
|
||||
if _LOGGER.isEnabledFor(logging.DEBUG):
|
||||
_LOGGER.debug(
|
||||
"Serving %s to %s (auth: %s)",
|
||||
request.path,
|
||||
request.remote,
|
||||
authenticated,
|
||||
)
|
||||
|
||||
try:
|
||||
if is_coroutinefunction:
|
||||
result = await handler(request, **request.match_info)
|
||||
else:
|
||||
result = handler(request, **request.match_info)
|
||||
except vol.Invalid as err:
|
||||
raise HTTPBadRequest() from err
|
||||
except exceptions.ServiceNotFound as err:
|
||||
raise HTTPInternalServerError() from err
|
||||
except exceptions.Unauthorized as err:
|
||||
raise HTTPUnauthorized() from err
|
||||
|
||||
if isinstance(result, web.StreamResponse):
|
||||
# The method handler returned a ready-made Response, how nice of it
|
||||
return result
|
||||
|
||||
status_code = HTTPStatus.OK
|
||||
if isinstance(result, tuple):
|
||||
result, status_code = result
|
||||
|
||||
if isinstance(result, bytes):
|
||||
return web.Response(body=result, status=status_code)
|
||||
|
||||
if isinstance(result, str):
|
||||
return web.Response(text=result, status=status_code)
|
||||
|
||||
if result is None:
|
||||
return web.Response(body=b"", status=status_code)
|
||||
|
||||
raise TypeError(
|
||||
f"Result should be None, string, bytes or StreamResponse. Got: {result}"
|
||||
)
|
||||
|
||||
return handle
|
||||
|
||||
|
||||
class HomeAssistantView:
|
||||
"""Base view for all views."""
|
||||
|
||||
url: str | None = None
|
||||
extra_urls: list[str] = []
|
||||
# Views inheriting from this class can override this
|
||||
requires_auth = True
|
||||
cors_allowed = False
|
||||
|
||||
@staticmethod
|
||||
def context(request: web.Request) -> Context:
|
||||
"""Generate a context from a request."""
|
||||
if (user := request.get("hass_user")) is None:
|
||||
return Context()
|
||||
|
||||
return Context(user_id=user.id)
|
||||
|
||||
@staticmethod
|
||||
def json(
|
||||
result: Any,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON response."""
|
||||
try:
|
||||
msg = json_bytes(result)
|
||||
except JSON_ENCODE_EXCEPTIONS as err:
|
||||
_LOGGER.error(
|
||||
"Unable to serialize to JSON. Bad data found at %s",
|
||||
format_unserializable_data(
|
||||
find_paths_unserializable_data(result, dump=json_dumps)
|
||||
),
|
||||
)
|
||||
raise HTTPInternalServerError from err
|
||||
response = web.Response(
|
||||
body=msg,
|
||||
content_type=CONTENT_TYPE_JSON,
|
||||
status=int(status_code),
|
||||
headers=headers,
|
||||
zlib_executor_size=32768,
|
||||
)
|
||||
response.enable_compression()
|
||||
return response
|
||||
|
||||
def json_message(
|
||||
self,
|
||||
message: str,
|
||||
status_code: HTTPStatus | int = HTTPStatus.OK,
|
||||
message_code: str | None = None,
|
||||
headers: LooseHeaders | None = None,
|
||||
) -> web.Response:
|
||||
"""Return a JSON message response."""
|
||||
data = {"message": message}
|
||||
if message_code is not None:
|
||||
data["code"] = message_code
|
||||
return self.json(data, status_code, headers=headers)
|
||||
|
||||
def register(
|
||||
self, hass: HomeAssistant, app: web.Application, router: web.UrlDispatcher
|
||||
) -> None:
|
||||
"""Register the view with a router."""
|
||||
assert self.url is not None, "No url set for view"
|
||||
urls = [self.url] + self.extra_urls
|
||||
routes: list[AbstractRoute] = []
|
||||
|
||||
for method in ("get", "post", "delete", "put", "patch", "head", "options"):
|
||||
if not (handler := getattr(self, method, None)):
|
||||
continue
|
||||
|
||||
handler = request_handler_factory(hass, self, handler)
|
||||
|
||||
for url in urls:
|
||||
routes.append(router.add_route(method, url, handler))
|
||||
|
||||
# Use `get` because CORS middleware is not be loaded in emulated_hue
|
||||
if self.cors_allowed:
|
||||
allow_cors = app.get("allow_all_cors")
|
||||
else:
|
||||
allow_cors = app.get("allow_configured_cors")
|
||||
|
||||
if allow_cors:
|
||||
for route in routes:
|
||||
allow_cors(route)
|
39
tests/test_circular_imports.py
Normal file
39
tests/test_circular_imports.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""Test to check for circular imports in core components."""
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from homeassistant.bootstrap import (
|
||||
CORE_INTEGRATIONS,
|
||||
DEBUGGER_INTEGRATIONS,
|
||||
DEFAULT_INTEGRATIONS,
|
||||
FRONTEND_INTEGRATIONS,
|
||||
LOGGING_INTEGRATIONS,
|
||||
RECORDER_INTEGRATIONS,
|
||||
STAGE_1_INTEGRATIONS,
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.timeout(30) # cloud can take > 9s
|
||||
@pytest.mark.parametrize(
|
||||
"component",
|
||||
sorted(
|
||||
{
|
||||
*DEBUGGER_INTEGRATIONS,
|
||||
*CORE_INTEGRATIONS,
|
||||
*LOGGING_INTEGRATIONS,
|
||||
*FRONTEND_INTEGRATIONS,
|
||||
*RECORDER_INTEGRATIONS,
|
||||
*STAGE_1_INTEGRATIONS,
|
||||
*DEFAULT_INTEGRATIONS,
|
||||
}
|
||||
),
|
||||
)
|
||||
async def test_circular_imports(component: str) -> None:
|
||||
"""Check that components can be imported without circular imports."""
|
||||
process = await asyncio.create_subprocess_exec(
|
||||
sys.executable, "-c", f"import homeassistant.components.{component}"
|
||||
)
|
||||
await process.communicate()
|
||||
assert process.returncode == 0
|
Loading…
x
Reference in New Issue
Block a user