Change KEY_HASS to be an aiohttp AppKey (#111954)

This commit is contained in:
Marc Mueller 2024-03-07 13:37:48 +01:00 committed by GitHub
parent 82efb3d35b
commit 531e25cbc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 39 additions and 25 deletions

View File

@ -143,6 +143,7 @@ from homeassistant.auth.models import (
User, User,
) )
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import ( from homeassistant.components.http.auth import (
async_sign_path, async_sign_path,
async_user_not_allowed_do_auth, async_user_not_allowed_do_auth,
@ -209,7 +210,7 @@ class RevokeTokenView(HomeAssistantView):
async def post(self, request: web.Request) -> web.Response: async def post(self, request: web.Request) -> web.Response:
"""Revoke a token.""" """Revoke a token."""
hass: HomeAssistant = request.app["hass"] hass = request.app[KEY_HASS]
data = cast(MultiDictProxy[str], await request.post()) data = cast(MultiDictProxy[str], await request.post())
# OAuth 2.0 Token Revocation [RFC7009] # OAuth 2.0 Token Revocation [RFC7009]
@ -243,7 +244,7 @@ class TokenView(HomeAssistantView):
@log_invalid_auth @log_invalid_auth
async def post(self, request: web.Request) -> web.Response: async def post(self, request: web.Request) -> web.Response:
"""Grant a token.""" """Grant a token."""
hass: HomeAssistant = request.app["hass"] hass = request.app[KEY_HASS]
data = cast(MultiDictProxy[str], await request.post()) data = cast(MultiDictProxy[str], await request.post())
grant_type = data.get("grant_type") grant_type = data.get("grant_type")
@ -415,7 +416,7 @@ class LinkUserView(HomeAssistantView):
@RequestDataValidator(vol.Schema({"code": str, "client_id": str})) @RequestDataValidator(vol.Schema({"code": str, "client_id": str}))
async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response: async def post(self, request: web.Request, data: dict[str, Any]) -> web.Response:
"""Link a user.""" """Link a user."""
hass: HomeAssistant = request.app["hass"] hass = request.app[KEY_HASS]
user: User = request["hass_user"] user: User = request["hass_user"]
credentials = self._retrieve_credentials(data["client_id"], data["code"]) credentials = self._retrieve_credentials(data["client_id"], data["code"])

View File

@ -81,6 +81,7 @@ from homeassistant import data_entry_flow
from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError from homeassistant.auth import AuthManagerFlowManager, InvalidAuthError
from homeassistant.auth.models import AuthFlowResult, Credentials from homeassistant.auth.models import AuthFlowResult, Credentials
from homeassistant.components import onboarding from homeassistant.components import onboarding
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import async_user_not_allowed_do_auth from homeassistant.components.http.auth import async_user_not_allowed_do_auth
from homeassistant.components.http.ban import ( from homeassistant.components.http.ban import (
log_invalid_auth, log_invalid_auth,
@ -144,7 +145,7 @@ class AuthProvidersView(HomeAssistantView):
async def get(self, request: web.Request) -> web.Response: async def get(self, request: web.Request) -> web.Response:
"""Get available auth providers.""" """Get available auth providers."""
hass: HomeAssistant = request.app["hass"] hass = request.app[KEY_HASS]
if not onboarding.async_is_user_onboarded(hass): if not onboarding.async_is_user_onboarded(hass):
return self.json_message( return self.json_message(
message="Onboarding not finished", message="Onboarding not finished",
@ -255,7 +256,7 @@ class LoginFlowBaseView(HomeAssistantView):
await process_wrong_login(request) await process_wrong_login(request)
return self.json(_prepare_result_json(result)) return self.json(_prepare_result_json(result))
hass: HomeAssistant = request.app["hass"] hass = request.app[KEY_HASS]
if not await indieauth.verify_redirect_uri( if not await indieauth.verify_redirect_uri(
hass, client_id, result["context"]["redirect_uri"] hass, client_id, result["context"]["redirect_uri"]

View File

@ -34,6 +34,7 @@ from homeassistant.helpers import storage
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.http import ( from homeassistant.helpers.http import (
KEY_AUTHENTICATED, # noqa: F401 KEY_AUTHENTICATED, # noqa: F401
KEY_HASS,
HomeAssistantView, HomeAssistantView,
current_request, current_request,
) )
@ -47,7 +48,7 @@ from homeassistant.util.json import json_loads
from .auth import async_setup_auth from .auth import async_setup_auth
from .ban import setup_bans from .ban import setup_bans
from .const import KEY_HASS, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401 from .const import KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER # noqa: F401
from .cors import setup_cors from .cors import setup_cors
from .decorators import require_admin # noqa: F401 from .decorators import require_admin # noqa: F401
from .forwarded import async_setup_forwarded from .forwarded import async_setup_forwarded
@ -323,6 +324,7 @@ class HomeAssistantHTTP:
) -> None: ) -> None:
"""Initialize the server.""" """Initialize the server."""
self.app[KEY_HASS] = self.hass self.app[KEY_HASS] = self.hass
self.app["hass"] = self.hass # For backwards compatibility
# Order matters, security filters middleware needs to go first, # Order matters, security filters middleware needs to go first,
# forwarded middleware needs to go second. # forwarded middleware needs to go second.

View File

@ -21,6 +21,7 @@ from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.util import dt as dt_util, yaml from homeassistant.util import dt as dt_util, yaml
from .const import KEY_HASS
from .view import HomeAssistantView from .view import HomeAssistantView
_HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView) _HassViewT = TypeVar("_HassViewT", bound=HomeAssistantView)
@ -105,7 +106,7 @@ async def process_wrong_login(request: Request) -> None:
Increase failed login attempts counter for remote IP address. Increase failed login attempts counter for remote IP address.
Add ip ban entry if failed login attempts exceeds threshold. Add ip ban entry if failed login attempts exceeds threshold.
""" """
hass = request.app["hass"] hass = request.app[KEY_HASS]
remote_addr = ip_address(request.remote) # type: ignore[arg-type] remote_addr = ip_address(request.remote) # type: ignore[arg-type]
remote_host = request.remote remote_host = request.remote

View File

@ -1,8 +1,7 @@
"""HTTP specific constants.""" """HTTP specific constants."""
from typing import Final from typing import Final
from homeassistant.helpers.http import KEY_AUTHENTICATED # noqa: F401 from homeassistant.helpers.http import KEY_AUTHENTICATED, KEY_HASS # noqa: F401
KEY_HASS: Final = "hass"
KEY_HASS_USER: Final = "hass_user" KEY_HASS_USER: Final = "hass_user"
KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id" KEY_HASS_REFRESH_TOKEN_ID: Final = "hass_refresh_token_id"

View File

@ -12,8 +12,6 @@ from aiohttp.web_exceptions import HTTPForbidden, HTTPNotFound
from aiohttp.web_urldispatcher import StaticResource from aiohttp.web_urldispatcher import StaticResource
from lru import LRU from lru import LRU
from homeassistant.core import HomeAssistant
from .const import KEY_HASS from .const import KEY_HASS
CACHE_TIME: Final = 31 * 86400 # = 1 month CACHE_TIME: Final = 31 * 86400 # = 1 month
@ -48,7 +46,7 @@ class CachingStaticResource(StaticResource):
rel_url = request.match_info["filename"] rel_url = request.match_info["filename"]
key = (rel_url, self._directory) key = (rel_url, self._directory)
if (filepath_content_type := PATH_CACHE.get(key)) is None: if (filepath_content_type := PATH_CACHE.get(key)) is None:
hass: HomeAssistant = request.app[KEY_HASS] hass = request.app[KEY_HASS]
try: try:
filepath = await hass.async_add_executor_job(_get_file_path, *key) filepath = await hass.async_add_executor_job(_get_file_path, *key)
except (ValueError, FileNotFoundError) as error: except (ValueError, FileNotFoundError) as error:

View File

@ -10,7 +10,7 @@ from typing import Any, Final
from aiohttp import web from aiohttp import web
from aiohttp.typedefs import LooseHeaders from aiohttp.typedefs import LooseHeaders
from aiohttp.web import Request from aiohttp.web import AppKey, Request
from aiohttp.web_exceptions import ( from aiohttp.web_exceptions import (
HTTPBadRequest, HTTPBadRequest,
HTTPInternalServerError, HTTPInternalServerError,
@ -30,6 +30,7 @@ _LOGGER = logging.getLogger(__name__)
KEY_AUTHENTICATED: Final = "ha_authenticated" KEY_AUTHENTICATED: Final = "ha_authenticated"
KEY_HASS: AppKey[HomeAssistant] = AppKey("hass")
current_request: ContextVar[Request | None] = ContextVar( current_request: ContextVar[Request | None] = ContextVar(
"current_request", default=None "current_request", default=None

View File

@ -17,6 +17,7 @@ from homeassistant.auth.providers.legacy_api_password import (
LegacyApiPasswordAuthProvider, LegacyApiPasswordAuthProvider,
) )
from homeassistant.components import websocket_api from homeassistant.components import websocket_api
from homeassistant.components.http import KEY_HASS
from homeassistant.components.http.auth import ( from homeassistant.components.http.auth import (
CONTENT_USER_NAME, CONTENT_USER_NAME,
DATA_SIGN_SECRET, DATA_SIGN_SECRET,
@ -78,7 +79,7 @@ async def get_legacy_user(auth):
def app(hass): def app(hass):
"""Fixture to set up a web.Application.""" """Fixture to set up a web.Application."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
app.router.add_get("/", mock_handler) app.router.add_get("/", mock_handler)
async_setup_forwarded(app, True, []) async_setup_forwarded(app, True, [])
return app return app
@ -88,7 +89,7 @@ def app(hass):
def app2(hass): def app2(hass):
"""Fixture to set up a web.Application without real_ip middleware.""" """Fixture to set up a web.Application without real_ip middleware."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
app.router.add_get("/", mock_handler) app.router.add_get("/", mock_handler)
return app return app

View File

@ -10,7 +10,7 @@ from aiohttp.web_middlewares import middleware
import pytest import pytest
import homeassistant.components.http as http import homeassistant.components.http as http
from homeassistant.components.http import KEY_AUTHENTICATED from homeassistant.components.http import KEY_AUTHENTICATED, KEY_HASS
from homeassistant.components.http.ban import ( from homeassistant.components.http.ban import (
IP_BANS_FILE, IP_BANS_FILE,
KEY_BAN_MANAGER, KEY_BAN_MANAGER,
@ -58,7 +58,7 @@ async def test_access_from_banned_ip(
) -> None: ) -> None:
"""Test accessing to server from banned IP. Both trusted and not.""" """Test accessing to server from banned IP. Both trusted and not."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
@ -87,7 +87,7 @@ async def test_access_from_banned_ip_with_partially_broken_yaml_file(
still load the bans. still load the bans.
""" """
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
@ -118,7 +118,7 @@ async def test_no_ip_bans_file(
) -> None: ) -> None:
"""Test no ip bans file.""" """Test no ip bans file."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
@ -138,7 +138,7 @@ async def test_failure_loading_ip_bans_file(
) -> None: ) -> None:
"""Test failure loading ip bans file.""" """Test failure loading ip bans file."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
@ -160,7 +160,7 @@ async def test_ip_ban_manager_never_started(
) -> None: ) -> None:
"""Test we handle the ip ban manager not being started.""" """Test we handle the ip ban manager not being started."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
setup_bans(hass, app, 5) setup_bans(hass, app, 5)
set_real_ip = mock_real_ip(app) set_real_ip = mock_real_ip(app)
@ -199,7 +199,7 @@ async def test_access_from_supervisor_ip(
) -> None: ) -> None:
"""Test accessing to server from supervisor IP.""" """Test accessing to server from supervisor IP."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
async def unauth_handler(request): async def unauth_handler(request):
"""Return a mock web response.""" """Return a mock web response."""
@ -270,7 +270,7 @@ async def test_ip_bans_file_creation(
) -> None: ) -> None:
"""Testing if banned IP file created.""" """Testing if banned IP file created."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
async def unauth_handler(request): async def unauth_handler(request):
"""Return a mock web response.""" """Return a mock web response."""
@ -326,7 +326,7 @@ async def test_failed_login_attempts_counter(
) -> None: ) -> None:
"""Testing if failed login attempts counter increased.""" """Testing if failed login attempts counter increased."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
async def auth_handler(request): async def auth_handler(request):
"""Return 200 status code.""" """Return 200 status code."""
@ -398,7 +398,7 @@ async def test_single_ban_file_entry(
) -> None: ) -> None:
"""Test that only one item is added to ban file.""" """Test that only one item is added to ban file."""
app = web.Application() app = web.Application()
app["hass"] = hass app[KEY_HASS] = hass
async def unauth_handler(request): async def unauth_handler(request):
"""Return a mock web response.""" """Return a mock web response."""

View File

@ -14,6 +14,7 @@ from homeassistant.auth.providers.legacy_api_password import (
) )
import homeassistant.components.http as http import homeassistant.components.http as http
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.http import KEY_HASS
from homeassistant.helpers.network import NoURLAvailableError from homeassistant.helpers.network import NoURLAvailableError
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
@ -97,6 +98,15 @@ async def test_registering_view_while_running(
hass.http.register_view(TestView) hass.http.register_view(TestView)
async def test_homeassistant_assigned_to_app(hass: HomeAssistant) -> None:
"""Test HomeAssistant instance is assigned to HomeAssistantApp."""
assert await async_setup_component(hass, "api", {"http": {}})
await hass.async_start()
assert hass.http.app[KEY_HASS] == hass
assert hass.http.app["hass"] == hass # For backwards compatibility
await hass.async_stop()
async def test_not_log_password( async def test_not_log_password(
hass: HomeAssistant, hass: HomeAssistant,
hass_client_no_auth: ClientSessionGenerator, hass_client_no_auth: ClientSessionGenerator,