mirror of
https://github.com/home-assistant/core.git
synced 2025-07-21 12:17:07 +00:00
Add local only users (#57598)
This commit is contained in:
parent
847b10fa65
commit
914f7f85ec
@ -214,11 +214,19 @@ class AuthManager:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
async def async_create_system_user(
|
async def async_create_system_user(
|
||||||
self, name: str, group_ids: list[str] | None = None
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
group_ids: list[str] | None = None,
|
||||||
|
local_only: bool | None = None,
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a system user."""
|
"""Create a system user."""
|
||||||
user = await self._store.async_create_user(
|
user = await self._store.async_create_user(
|
||||||
name=name, system_generated=True, is_active=True, group_ids=group_ids or []
|
name=name,
|
||||||
|
system_generated=True,
|
||||||
|
is_active=True,
|
||||||
|
group_ids=group_ids or [],
|
||||||
|
local_only=local_only,
|
||||||
)
|
)
|
||||||
|
|
||||||
self.hass.bus.async_fire(EVENT_USER_ADDED, {"user_id": user.id})
|
self.hass.bus.async_fire(EVENT_USER_ADDED, {"user_id": user.id})
|
||||||
@ -226,13 +234,18 @@ class AuthManager:
|
|||||||
return user
|
return user
|
||||||
|
|
||||||
async def async_create_user(
|
async def async_create_user(
|
||||||
self, name: str, group_ids: list[str] | None = None
|
self,
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
group_ids: list[str] | None = None,
|
||||||
|
local_only: bool | None = None,
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
kwargs: dict[str, Any] = {
|
kwargs: dict[str, Any] = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"is_active": True,
|
"is_active": True,
|
||||||
"group_ids": group_ids or [],
|
"group_ids": group_ids or [],
|
||||||
|
"local_only": local_only,
|
||||||
}
|
}
|
||||||
|
|
||||||
if await self._user_should_be_owner():
|
if await self._user_should_be_owner():
|
||||||
@ -304,13 +317,18 @@ class AuthManager:
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
local_only: bool | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user."""
|
"""Update a user."""
|
||||||
kwargs: dict[str, Any] = {}
|
kwargs: dict[str, Any] = {}
|
||||||
if name is not None:
|
|
||||||
kwargs["name"] = name
|
for attr_name, value in (
|
||||||
if group_ids is not None:
|
("name", name),
|
||||||
kwargs["group_ids"] = group_ids
|
("group_ids", group_ids),
|
||||||
|
("local_only", local_only),
|
||||||
|
):
|
||||||
|
if value is not None:
|
||||||
|
kwargs[attr_name] = value
|
||||||
await self._store.async_update_user(user, **kwargs)
|
await self._store.async_update_user(user, **kwargs)
|
||||||
|
|
||||||
if is_active is not None:
|
if is_active is not None:
|
||||||
|
@ -86,6 +86,7 @@ class AuthStore:
|
|||||||
system_generated: bool | None = None,
|
system_generated: bool | None = None,
|
||||||
credentials: models.Credentials | None = None,
|
credentials: models.Credentials | None = None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
local_only: bool | None = None,
|
||||||
) -> models.User:
|
) -> models.User:
|
||||||
"""Create a new user."""
|
"""Create a new user."""
|
||||||
if self._users is None:
|
if self._users is None:
|
||||||
@ -108,14 +109,14 @@ class AuthStore:
|
|||||||
"perm_lookup": self._perm_lookup,
|
"perm_lookup": self._perm_lookup,
|
||||||
}
|
}
|
||||||
|
|
||||||
if is_owner is not None:
|
for attr_name, value in (
|
||||||
kwargs["is_owner"] = is_owner
|
("is_owner", is_owner),
|
||||||
|
("is_active", is_active),
|
||||||
if is_active is not None:
|
("local_only", local_only),
|
||||||
kwargs["is_active"] = is_active
|
("system_generated", system_generated),
|
||||||
|
):
|
||||||
if system_generated is not None:
|
if value is not None:
|
||||||
kwargs["system_generated"] = system_generated
|
kwargs[attr_name] = value
|
||||||
|
|
||||||
new_user = models.User(**kwargs)
|
new_user = models.User(**kwargs)
|
||||||
|
|
||||||
@ -152,6 +153,7 @@ class AuthStore:
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
is_active: bool | None = None,
|
is_active: bool | None = None,
|
||||||
group_ids: list[str] | None = None,
|
group_ids: list[str] | None = None,
|
||||||
|
local_only: bool | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Update a user."""
|
"""Update a user."""
|
||||||
assert self._groups is not None
|
assert self._groups is not None
|
||||||
@ -166,7 +168,11 @@ class AuthStore:
|
|||||||
user.groups = groups
|
user.groups = groups
|
||||||
user.invalidate_permission_cache()
|
user.invalidate_permission_cache()
|
||||||
|
|
||||||
for attr_name, value in (("name", name), ("is_active", is_active)):
|
for attr_name, value in (
|
||||||
|
("name", name),
|
||||||
|
("is_active", is_active),
|
||||||
|
("local_only", local_only),
|
||||||
|
):
|
||||||
if value is not None:
|
if value is not None:
|
||||||
setattr(user, attr_name, value)
|
setattr(user, attr_name, value)
|
||||||
|
|
||||||
@ -417,6 +423,8 @@ class AuthStore:
|
|||||||
is_active=user_dict["is_active"],
|
is_active=user_dict["is_active"],
|
||||||
system_generated=user_dict["system_generated"],
|
system_generated=user_dict["system_generated"],
|
||||||
perm_lookup=perm_lookup,
|
perm_lookup=perm_lookup,
|
||||||
|
# New in 2021.11
|
||||||
|
local_only=user_dict.get("local_only", False),
|
||||||
)
|
)
|
||||||
|
|
||||||
for cred_dict in data["credentials"]:
|
for cred_dict in data["credentials"]:
|
||||||
@ -502,6 +510,7 @@ class AuthStore:
|
|||||||
"is_active": user.is_active,
|
"is_active": user.is_active,
|
||||||
"name": user.name,
|
"name": user.name,
|
||||||
"system_generated": user.system_generated,
|
"system_generated": user.system_generated,
|
||||||
|
"local_only": user.local_only,
|
||||||
}
|
}
|
||||||
for user in self._users.values()
|
for user in self._users.values()
|
||||||
]
|
]
|
||||||
|
@ -39,6 +39,7 @@ class User:
|
|||||||
is_owner: bool = attr.ib(default=False)
|
is_owner: bool = attr.ib(default=False)
|
||||||
is_active: bool = attr.ib(default=False)
|
is_active: bool = attr.ib(default=False)
|
||||||
system_generated: bool = attr.ib(default=False)
|
system_generated: bool = attr.ib(default=False)
|
||||||
|
local_only: bool = attr.ib(default=False)
|
||||||
|
|
||||||
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
|
groups: list[Group] = attr.ib(factory=list, eq=False, order=False)
|
||||||
|
|
||||||
|
@ -177,7 +177,9 @@ async def _configure_almond_for_ha(
|
|||||||
user = await hass.auth.async_get_user(data["almond_user"])
|
user = await hass.auth.async_get_user(data["almond_user"])
|
||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await hass.auth.async_create_system_user("Almond", [GROUP_ID_ADMIN])
|
user = await hass.auth.async_create_system_user(
|
||||||
|
"Almond", group_ids=[GROUP_ID_ADMIN]
|
||||||
|
)
|
||||||
data["almond_user"] = user.id
|
data["almond_user"] = user.id
|
||||||
await store.async_save(data)
|
await store.async_save(data)
|
||||||
|
|
||||||
|
@ -126,7 +126,10 @@ import voluptuous as vol
|
|||||||
from homeassistant.auth import InvalidAuthError
|
from homeassistant.auth import InvalidAuthError
|
||||||
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
|
from homeassistant.auth.models import TOKEN_TYPE_LONG_LIVED_ACCESS_TOKEN, Credentials
|
||||||
from homeassistant.components import websocket_api
|
from homeassistant.components import websocket_api
|
||||||
from homeassistant.components.http.auth import async_sign_path
|
from homeassistant.components.http.auth import (
|
||||||
|
async_sign_path,
|
||||||
|
async_user_not_allowed_do_auth,
|
||||||
|
)
|
||||||
from homeassistant.components.http.ban import log_invalid_auth
|
from homeassistant.components.http.ban import log_invalid_auth
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
@ -299,9 +302,12 @@ class TokenView(HomeAssistantView):
|
|||||||
|
|
||||||
user = await hass.auth.async_get_or_create_user(credential)
|
user = await hass.auth.async_get_or_create_user(credential)
|
||||||
|
|
||||||
if not user.is_active:
|
if user_access_error := async_user_not_allowed_do_auth(hass, user):
|
||||||
return self.json(
|
return self.json(
|
||||||
{"error": "access_denied", "error_description": "User is not active"},
|
{
|
||||||
|
"error": "access_denied",
|
||||||
|
"error_description": user_access_error,
|
||||||
|
},
|
||||||
status_code=HTTPStatus.FORBIDDEN,
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -355,6 +361,17 @@ class TokenView(HomeAssistantView):
|
|||||||
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
|
{"error": "invalid_request"}, status_code=HTTPStatus.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if user_access_error := async_user_not_allowed_do_auth(
|
||||||
|
hass, refresh_token.user
|
||||||
|
):
|
||||||
|
return self.json(
|
||||||
|
{
|
||||||
|
"error": "access_denied",
|
||||||
|
"error_description": user_access_error,
|
||||||
|
},
|
||||||
|
status_code=HTTPStatus.FORBIDDEN,
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
access_token = hass.auth.async_create_access_token(
|
access_token = hass.auth.async_create_access_token(
|
||||||
refresh_token, remote_addr
|
refresh_token, remote_addr
|
||||||
|
@ -74,6 +74,8 @@ import voluptuous as vol
|
|||||||
import voluptuous_serialize
|
import voluptuous_serialize
|
||||||
|
|
||||||
from homeassistant import data_entry_flow
|
from homeassistant import data_entry_flow
|
||||||
|
from homeassistant.auth.models import Credentials
|
||||||
|
from homeassistant.components.http.auth import async_user_not_allowed_do_auth
|
||||||
from homeassistant.components.http.ban import (
|
from homeassistant.components.http.ban import (
|
||||||
log_invalid_auth,
|
log_invalid_auth,
|
||||||
process_success_login,
|
process_success_login,
|
||||||
@ -81,6 +83,7 @@ from homeassistant.components.http.ban import (
|
|||||||
)
|
)
|
||||||
from homeassistant.components.http.data_validator import RequestDataValidator
|
from homeassistant.components.http.data_validator import RequestDataValidator
|
||||||
from homeassistant.components.http.view import HomeAssistantView
|
from homeassistant.components.http.view import HomeAssistantView
|
||||||
|
from homeassistant.core import HomeAssistant
|
||||||
|
|
||||||
from . import indieauth
|
from . import indieauth
|
||||||
|
|
||||||
@ -138,11 +141,9 @@ def _prepare_result_json(result):
|
|||||||
return data
|
return data
|
||||||
|
|
||||||
|
|
||||||
class LoginFlowIndexView(HomeAssistantView):
|
class LoginFlowBaseView(HomeAssistantView):
|
||||||
"""View to create a config flow."""
|
"""Base class for the login views."""
|
||||||
|
|
||||||
url = "/auth/login_flow"
|
|
||||||
name = "api:auth:login_flow"
|
|
||||||
requires_auth = False
|
requires_auth = False
|
||||||
|
|
||||||
def __init__(self, flow_mgr, store_result):
|
def __init__(self, flow_mgr, store_result):
|
||||||
@ -150,6 +151,46 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||||||
self._flow_mgr = flow_mgr
|
self._flow_mgr = flow_mgr
|
||||||
self._store_result = store_result
|
self._store_result = store_result
|
||||||
|
|
||||||
|
async def _async_flow_result_to_response(self, request, client_id, result):
|
||||||
|
"""Convert the flow result to a response."""
|
||||||
|
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
||||||
|
if result["type"] == data_entry_flow.RESULT_TYPE_FORM:
|
||||||
|
# @log_invalid_auth does not work here since it returns HTTP 200
|
||||||
|
# need manually log failed login attempts
|
||||||
|
if result.get("errors", {}).get("base") in (
|
||||||
|
"invalid_auth",
|
||||||
|
"invalid_code",
|
||||||
|
):
|
||||||
|
await process_wrong_login(request)
|
||||||
|
return self.json(_prepare_result_json(result))
|
||||||
|
|
||||||
|
result.pop("data")
|
||||||
|
|
||||||
|
hass: HomeAssistant = request.app["hass"]
|
||||||
|
result_obj: Credentials = result.pop("result")
|
||||||
|
|
||||||
|
# Result can be None if credential was never linked to a user before.
|
||||||
|
user = await hass.auth.async_get_user_by_credentials(result_obj)
|
||||||
|
|
||||||
|
if user is not None and (
|
||||||
|
user_access_error := async_user_not_allowed_do_auth(hass, user)
|
||||||
|
):
|
||||||
|
return self.json_message(
|
||||||
|
f"Login blocked: {user_access_error}", HTTPStatus.FORBIDDEN
|
||||||
|
)
|
||||||
|
|
||||||
|
await process_success_login(request)
|
||||||
|
result["result"] = self._store_result(client_id, result_obj)
|
||||||
|
|
||||||
|
return self.json(result)
|
||||||
|
|
||||||
|
|
||||||
|
class LoginFlowIndexView(LoginFlowBaseView):
|
||||||
|
"""View to create a config flow."""
|
||||||
|
|
||||||
|
url = "/auth/login_flow"
|
||||||
|
name = "api:auth:login_flow"
|
||||||
|
|
||||||
async def get(self, request):
|
async def get(self, request):
|
||||||
"""Do not allow index of flows in progress."""
|
"""Do not allow index of flows in progress."""
|
||||||
# pylint: disable=no-self-use
|
# pylint: disable=no-self-use
|
||||||
@ -195,26 +236,16 @@ class LoginFlowIndexView(HomeAssistantView):
|
|||||||
"Handler does not support init", HTTPStatus.BAD_REQUEST
|
"Handler does not support init", HTTPStatus.BAD_REQUEST
|
||||||
)
|
)
|
||||||
|
|
||||||
if result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
return await self._async_flow_result_to_response(
|
||||||
await process_success_login(request)
|
request, data["client_id"], result
|
||||||
result.pop("data")
|
)
|
||||||
result["result"] = self._store_result(data["client_id"], result["result"])
|
|
||||||
return self.json(result)
|
|
||||||
|
|
||||||
return self.json(_prepare_result_json(result))
|
|
||||||
|
|
||||||
|
|
||||||
class LoginFlowResourceView(HomeAssistantView):
|
class LoginFlowResourceView(LoginFlowBaseView):
|
||||||
"""View to interact with the flow manager."""
|
"""View to interact with the flow manager."""
|
||||||
|
|
||||||
url = "/auth/login_flow/{flow_id}"
|
url = "/auth/login_flow/{flow_id}"
|
||||||
name = "api:auth:login_flow:resource"
|
name = "api:auth:login_flow:resource"
|
||||||
requires_auth = False
|
|
||||||
|
|
||||||
def __init__(self, flow_mgr, store_result):
|
|
||||||
"""Initialize the login flow resource view."""
|
|
||||||
self._flow_mgr = flow_mgr
|
|
||||||
self._store_result = store_result
|
|
||||||
|
|
||||||
async def get(self, request):
|
async def get(self, request):
|
||||||
"""Do not allow getting status of a flow in progress."""
|
"""Do not allow getting status of a flow in progress."""
|
||||||
@ -240,20 +271,7 @@ class LoginFlowResourceView(HomeAssistantView):
|
|||||||
except vol.Invalid:
|
except vol.Invalid:
|
||||||
return self.json_message("User input malformed", HTTPStatus.BAD_REQUEST)
|
return self.json_message("User input malformed", HTTPStatus.BAD_REQUEST)
|
||||||
|
|
||||||
if result["type"] != data_entry_flow.RESULT_TYPE_CREATE_ENTRY:
|
return await self._async_flow_result_to_response(request, client_id, result)
|
||||||
# @log_invalid_auth does not work here since it returns HTTP 200
|
|
||||||
# need manually log failed login attempts
|
|
||||||
if result.get("errors") is not None and result["errors"].get("base") in (
|
|
||||||
"invalid_auth",
|
|
||||||
"invalid_code",
|
|
||||||
):
|
|
||||||
await process_wrong_login(request)
|
|
||||||
return self.json(_prepare_result_json(result))
|
|
||||||
|
|
||||||
result.pop("data")
|
|
||||||
result["result"] = self._store_result(client_id, result["result"])
|
|
||||||
|
|
||||||
return self.json(result)
|
|
||||||
|
|
||||||
async def delete(self, request, flow_id):
|
async def delete(self, request, flow_id):
|
||||||
"""Cancel a flow in progress."""
|
"""Cancel a flow in progress."""
|
||||||
|
@ -28,7 +28,7 @@ async def async_setup_ha_cast(
|
|||||||
|
|
||||||
if user is None:
|
if user is None:
|
||||||
user = await hass.auth.async_create_system_user(
|
user = await hass.auth.async_create_system_user(
|
||||||
"Home Assistant Cast", [auth.GROUP_ID_ADMIN]
|
"Home Assistant Cast", group_ids=[auth.GROUP_ID_ADMIN]
|
||||||
)
|
)
|
||||||
hass.config_entries.async_update_entry(
|
hass.config_entries.async_update_entry(
|
||||||
entry, data={**entry.data, "user_id": user.id}
|
entry, data={**entry.data, "user_id": user.id}
|
||||||
|
@ -281,7 +281,7 @@ class CloudPreferences:
|
|||||||
return user.id
|
return user.id
|
||||||
|
|
||||||
user = await self._hass.auth.async_create_system_user(
|
user = await self._hass.auth.async_create_system_user(
|
||||||
"Home Assistant Cloud", [GROUP_ID_ADMIN]
|
"Home Assistant Cloud", group_ids=[GROUP_ID_ADMIN], local_only=True
|
||||||
)
|
)
|
||||||
assert user is not None
|
assert user is not None
|
||||||
await self.async_update(cloud_user=user.id)
|
await self.async_update(cloud_user=user.id)
|
||||||
|
@ -66,11 +66,14 @@ async def websocket_delete(hass, connection, msg):
|
|||||||
vol.Required("type"): "config/auth/create",
|
vol.Required("type"): "config/auth/create",
|
||||||
vol.Required("name"): str,
|
vol.Required("name"): str,
|
||||||
vol.Optional("group_ids"): [str],
|
vol.Optional("group_ids"): [str],
|
||||||
|
vol.Optional("local_only"): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async def websocket_create(hass, connection, msg):
|
async def websocket_create(hass, connection, msg):
|
||||||
"""Create a user."""
|
"""Create a user."""
|
||||||
user = await hass.auth.async_create_user(msg["name"], msg.get("group_ids"))
|
user = await hass.auth.async_create_user(
|
||||||
|
msg["name"], group_ids=msg.get("group_ids"), local_only=msg.get("local_only")
|
||||||
|
)
|
||||||
|
|
||||||
connection.send_message(
|
connection.send_message(
|
||||||
websocket_api.result_message(msg["id"], {"user": _user_info(user)})
|
websocket_api.result_message(msg["id"], {"user": _user_info(user)})
|
||||||
@ -86,6 +89,7 @@ async def websocket_create(hass, connection, msg):
|
|||||||
vol.Optional("name"): str,
|
vol.Optional("name"): str,
|
||||||
vol.Optional("is_active"): bool,
|
vol.Optional("is_active"): bool,
|
||||||
vol.Optional("group_ids"): [str],
|
vol.Optional("group_ids"): [str],
|
||||||
|
vol.Optional("local_only"): bool,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
async def websocket_update(hass, connection, msg):
|
async def websocket_update(hass, connection, msg):
|
||||||
|
@ -442,7 +442,9 @@ async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: # noqa:
|
|||||||
await hass.auth.async_update_user(user, name="Supervisor")
|
await hass.auth.async_update_user(user, name="Supervisor")
|
||||||
|
|
||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
user = await hass.auth.async_create_system_user("Supervisor", [GROUP_ID_ADMIN])
|
user = await hass.auth.async_create_system_user(
|
||||||
|
"Supervisor", group_ids=[GROUP_ID_ADMIN]
|
||||||
|
)
|
||||||
refresh_token = await hass.auth.async_create_refresh_token(user)
|
refresh_token = await hass.auth.async_create_refresh_token(user)
|
||||||
data["hassio_user"] = user.id
|
data["hassio_user"] = user.id
|
||||||
await store.async_save(data)
|
await store.async_save(data)
|
||||||
|
@ -1,7 +1,6 @@
|
|||||||
"""Support to serve the Home Assistant API as WSGI application."""
|
"""Support to serve the Home Assistant API as WSGI application."""
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from ipaddress import ip_network
|
from ipaddress import ip_network
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
@ -28,7 +27,7 @@ from .ban import setup_bans
|
|||||||
from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER # noqa: F401
|
from .const import KEY_AUTHENTICATED, KEY_HASS, KEY_HASS_USER # noqa: F401
|
||||||
from .cors import setup_cors
|
from .cors import setup_cors
|
||||||
from .forwarded import async_setup_forwarded
|
from .forwarded import async_setup_forwarded
|
||||||
from .request_context import setup_request_context
|
from .request_context import current_request, setup_request_context
|
||||||
from .security_filter import setup_security_filter
|
from .security_filter import setup_security_filter
|
||||||
from .static import CACHE_HEADERS, CachingStaticResource
|
from .static import CACHE_HEADERS, CachingStaticResource
|
||||||
from .view import HomeAssistantView
|
from .view import HomeAssistantView
|
||||||
@ -401,8 +400,3 @@ async def start_http_server_and_save_config(
|
|||||||
]
|
]
|
||||||
|
|
||||||
store.async_delay_save(lambda: conf, SAVE_DELAY)
|
store.async_delay_save(lambda: conf, SAVE_DELAY)
|
||||||
|
|
||||||
|
|
||||||
current_request: ContextVar[web.Request | None] = ContextVar(
|
|
||||||
"current_request", default=None
|
|
||||||
)
|
|
||||||
|
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from collections.abc import Awaitable, Callable
|
from collections.abc import Awaitable, Callable
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
|
from ipaddress import ip_address
|
||||||
import logging
|
import logging
|
||||||
import secrets
|
import secrets
|
||||||
from typing import Final
|
from typing import Final
|
||||||
@ -12,10 +13,13 @@ from aiohttp import hdrs
|
|||||||
from aiohttp.web import Application, Request, StreamResponse, middleware
|
from aiohttp.web import Application, Request, StreamResponse, middleware
|
||||||
import jwt
|
import jwt
|
||||||
|
|
||||||
|
from homeassistant.auth.models import User
|
||||||
from homeassistant.core import HomeAssistant, callback
|
from homeassistant.core import HomeAssistant, callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
from homeassistant.util.network import is_local
|
||||||
|
|
||||||
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
from .const import KEY_AUTHENTICATED, KEY_HASS_REFRESH_TOKEN_ID, KEY_HASS_USER
|
||||||
|
from .request_context import current_request
|
||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
@ -46,6 +50,42 @@ def async_sign_path(
|
|||||||
return f"{path}?{SIGN_QUERY_PARAM}={encoded}"
|
return f"{path}?{SIGN_QUERY_PARAM}={encoded}"
|
||||||
|
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_user_not_allowed_do_auth(
|
||||||
|
hass: HomeAssistant, user: User, request: Request | None = None
|
||||||
|
) -> str | None:
|
||||||
|
"""Validate that user is not allowed to do auth things."""
|
||||||
|
if not user.is_active:
|
||||||
|
return "User is not active"
|
||||||
|
|
||||||
|
if not user.local_only:
|
||||||
|
return None
|
||||||
|
|
||||||
|
# User is marked as local only, check if they are allowed to do auth
|
||||||
|
if request is None:
|
||||||
|
request = current_request.get()
|
||||||
|
|
||||||
|
if not request:
|
||||||
|
return "No request available to validate local access"
|
||||||
|
|
||||||
|
if "cloud" in hass.config.components:
|
||||||
|
# pylint: disable=import-outside-toplevel
|
||||||
|
from hass_nabucasa import remote
|
||||||
|
|
||||||
|
if remote.is_cloud_request.get():
|
||||||
|
return "User is local only"
|
||||||
|
|
||||||
|
try:
|
||||||
|
remote = ip_address(request.remote)
|
||||||
|
except ValueError:
|
||||||
|
return "Invalid remote IP"
|
||||||
|
|
||||||
|
if is_local(remote):
|
||||||
|
return None
|
||||||
|
|
||||||
|
return "User cannot authenticate remotely"
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_auth(hass: HomeAssistant, app: Application) -> None:
|
def setup_auth(hass: HomeAssistant, app: Application) -> None:
|
||||||
"""Create auth middleware for the app."""
|
"""Create auth middleware for the app."""
|
||||||
@ -72,6 +112,9 @@ def setup_auth(hass: HomeAssistant, app: Application) -> None:
|
|||||||
if refresh_token is None:
|
if refresh_token is None:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
if async_user_not_allowed_do_auth(hass, refresh_token.user, request):
|
||||||
|
return False
|
||||||
|
|
||||||
request[KEY_HASS_USER] = refresh_token.user
|
request[KEY_HASS_USER] = refresh_token.user
|
||||||
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
request[KEY_HASS_REFRESH_TOKEN_ID] = refresh_token.id
|
||||||
return True
|
return True
|
||||||
|
@ -8,6 +8,10 @@ from aiohttp.web import Application, Request, StreamResponse, middleware
|
|||||||
|
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
|
|
||||||
|
current_request: ContextVar[Request | None] = ContextVar(
|
||||||
|
"current_request", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@callback
|
@callback
|
||||||
def setup_request_context(
|
def setup_request_context(
|
||||||
|
@ -129,7 +129,9 @@ class UserOnboardingView(_BaseOnboardingView):
|
|||||||
provider = _async_get_hass_provider(hass)
|
provider = _async_get_hass_provider(hass)
|
||||||
await provider.async_initialize()
|
await provider.async_initialize()
|
||||||
|
|
||||||
user = await hass.auth.async_create_user(data["name"], [GROUP_ID_ADMIN])
|
user = await hass.auth.async_create_user(
|
||||||
|
data["name"], group_ids=[GROUP_ID_ADMIN]
|
||||||
|
)
|
||||||
await hass.async_add_executor_job(
|
await hass.async_add_executor_job(
|
||||||
provider.data.add_auth, data["username"], data["password"]
|
provider.data.add_auth, data["username"], data["password"]
|
||||||
)
|
)
|
||||||
|
@ -13,7 +13,7 @@ from homeassistant.auth import (
|
|||||||
const as auth_const,
|
const as auth_const,
|
||||||
models as auth_models,
|
models as auth_models,
|
||||||
)
|
)
|
||||||
from homeassistant.auth.const import MFA_SESSION_EXPIRATION
|
from homeassistant.auth.const import GROUP_ID_ADMIN, MFA_SESSION_EXPIRATION
|
||||||
from homeassistant.core import callback
|
from homeassistant.core import callback
|
||||||
from homeassistant.util import dt as dt_util
|
from homeassistant.util import dt as dt_util
|
||||||
|
|
||||||
@ -390,6 +390,8 @@ async def test_generating_system_user(hass):
|
|||||||
user = await manager.async_create_system_user("Hass.io")
|
user = await manager.async_create_system_user("Hass.io")
|
||||||
token = await manager.async_create_refresh_token(user)
|
token = await manager.async_create_refresh_token(user)
|
||||||
assert user.system_generated
|
assert user.system_generated
|
||||||
|
assert user.groups == []
|
||||||
|
assert not user.local_only
|
||||||
assert token is not None
|
assert token is not None
|
||||||
assert token.client_id is None
|
assert token.client_id is None
|
||||||
|
|
||||||
@ -397,6 +399,21 @@ async def test_generating_system_user(hass):
|
|||||||
assert len(events) == 1
|
assert len(events) == 1
|
||||||
assert events[0].data["user_id"] == user.id
|
assert events[0].data["user_id"] == user.id
|
||||||
|
|
||||||
|
# Passing arguments
|
||||||
|
user = await manager.async_create_system_user(
|
||||||
|
"Hass.io", group_ids=[GROUP_ID_ADMIN], local_only=True
|
||||||
|
)
|
||||||
|
token = await manager.async_create_refresh_token(user)
|
||||||
|
assert user.system_generated
|
||||||
|
assert user.is_admin
|
||||||
|
assert user.local_only
|
||||||
|
assert token is not None
|
||||||
|
assert token.client_id is None
|
||||||
|
|
||||||
|
await hass.async_block_till_done()
|
||||||
|
assert len(events) == 2
|
||||||
|
assert events[1].data["user_id"] == user.id
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_requires_client_for_user(hass):
|
async def test_refresh_token_requires_client_for_user(hass):
|
||||||
"""Test create refresh token for a user with client_id."""
|
"""Test create refresh token for a user with client_id."""
|
||||||
@ -1038,15 +1055,19 @@ async def test_new_users(mock_hass):
|
|||||||
# first user in the system is owner and admin
|
# first user in the system is owner and admin
|
||||||
assert user.is_owner
|
assert user.is_owner
|
||||||
assert user.is_admin
|
assert user.is_admin
|
||||||
|
assert not user.local_only
|
||||||
assert user.groups == []
|
assert user.groups == []
|
||||||
|
|
||||||
user = await manager.async_create_user("Hello 2")
|
user = await manager.async_create_user("Hello 2")
|
||||||
assert not user.is_admin
|
assert not user.is_admin
|
||||||
assert user.groups == []
|
assert user.groups == []
|
||||||
|
|
||||||
user = await manager.async_create_user("Hello 3", ["system-admin"])
|
user = await manager.async_create_user(
|
||||||
|
"Hello 3", group_ids=["system-admin"], local_only=True
|
||||||
|
)
|
||||||
assert user.is_admin
|
assert user.is_admin
|
||||||
assert user.groups[0].id == "system-admin"
|
assert user.groups[0].id == "system-admin"
|
||||||
|
assert user.local_only
|
||||||
|
|
||||||
user_cred = await manager.async_get_or_create_user(
|
user_cred = await manager.async_get_or_create_user(
|
||||||
auth_models.Credentials(
|
auth_models.Credentials(
|
||||||
|
@ -109,6 +109,48 @@ async def test_login_new_user_and_trying_refresh_token(hass, aiohttp_client):
|
|||||||
assert resp.status == HTTPStatus.OK
|
assert resp.status == HTTPStatus.OK
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_code_checks_local_only_user(hass, aiohttp_client):
|
||||||
|
"""Test local only user cannot exchange auth code for refresh tokens when external."""
|
||||||
|
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/login_flow",
|
||||||
|
json={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"handler": ["insecure_example", None],
|
||||||
|
"redirect_uri": CLIENT_REDIRECT_URI,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
step = await resp.json()
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
f"/auth/login_flow/{step['flow_id']}",
|
||||||
|
json={"client_id": CLIENT_ID, "username": "test-user", "password": "test-pass"},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
step = await resp.json()
|
||||||
|
code = step["result"]
|
||||||
|
|
||||||
|
# Exchange code for tokens
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.auth.async_user_not_allowed_do_auth",
|
||||||
|
return_value="User is local only",
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/token",
|
||||||
|
data={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"code": code,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.FORBIDDEN
|
||||||
|
error = await resp.json()
|
||||||
|
assert error["error"] == "access_denied"
|
||||||
|
|
||||||
|
|
||||||
def test_auth_code_store_expiration(mock_credential):
|
def test_auth_code_store_expiration(mock_credential):
|
||||||
"""Test that the auth code store will not return expired tokens."""
|
"""Test that the auth code store will not return expired tokens."""
|
||||||
store, retrieve = auth._create_auth_code_store()
|
store, retrieve = auth._create_auth_code_store()
|
||||||
@ -264,6 +306,30 @@ async def test_refresh_token_different_client_id(hass, aiohttp_client):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_refresh_token_checks_local_only_user(hass, aiohttp_client):
|
||||||
|
"""Test that we can't refresh token for a local only user when external."""
|
||||||
|
client = await async_setup_auth(hass, aiohttp_client)
|
||||||
|
refresh_token = await async_setup_user_refresh_token(hass)
|
||||||
|
refresh_token.user.local_only = True
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.auth.async_user_not_allowed_do_auth",
|
||||||
|
return_value="User is local only",
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/token",
|
||||||
|
data={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"grant_type": "refresh_token",
|
||||||
|
"refresh_token": refresh_token.token,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status == HTTPStatus.FORBIDDEN
|
||||||
|
result = await resp.json()
|
||||||
|
assert result["error"] == "access_denied"
|
||||||
|
|
||||||
|
|
||||||
async def test_refresh_token_provider_rejected(
|
async def test_refresh_token_provider_rejected(
|
||||||
hass, aiohttp_client, hass_admin_user, hass_admin_credential
|
hass, aiohttp_client, hass_admin_user, hass_admin_credential
|
||||||
):
|
):
|
||||||
|
@ -116,6 +116,44 @@ async def test_login_exist_user(hass, aiohttp_client):
|
|||||||
assert len(step["result"]) > 1
|
assert len(step["result"]) > 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_login_local_only_user(hass, aiohttp_client):
|
||||||
|
"""Test logging in with local only user."""
|
||||||
|
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||||
|
cred = await hass.auth.auth_providers[0].async_get_or_create_credentials(
|
||||||
|
{"username": "test-user"}
|
||||||
|
)
|
||||||
|
user = await hass.auth.async_get_or_create_user(cred)
|
||||||
|
await hass.auth.async_update_user(user, local_only=True)
|
||||||
|
|
||||||
|
resp = await client.post(
|
||||||
|
"/auth/login_flow",
|
||||||
|
json={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"handler": ["insecure_example", None],
|
||||||
|
"redirect_uri": CLIENT_REDIRECT_URI,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == HTTPStatus.OK
|
||||||
|
step = await resp.json()
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"homeassistant.components.auth.login_flow.async_user_not_allowed_do_auth",
|
||||||
|
return_value="User is local only",
|
||||||
|
) as mock_not_allowed_do_auth:
|
||||||
|
resp = await client.post(
|
||||||
|
f"/auth/login_flow/{step['flow_id']}",
|
||||||
|
json={
|
||||||
|
"client_id": CLIENT_ID,
|
||||||
|
"username": "test-user",
|
||||||
|
"password": "test-pass",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(mock_not_allowed_do_auth.mock_calls) == 1
|
||||||
|
assert resp.status == HTTPStatus.FORBIDDEN
|
||||||
|
assert await resp.json() == {"message": "Login blocked: User is local only"}
|
||||||
|
|
||||||
|
|
||||||
async def test_login_exist_user_ip_changes(hass, aiohttp_client):
|
async def test_login_exist_user_ip_changes(hass, aiohttp_client):
|
||||||
"""Test logging in and the ip address changes results in an rejection."""
|
"""Test logging in and the ip address changes results in an rejection."""
|
||||||
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
client = await async_setup_auth(hass, aiohttp_client, setup_api=True)
|
||||||
|
@ -2,14 +2,18 @@
|
|||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from http import HTTPStatus
|
from http import HTTPStatus
|
||||||
from ipaddress import ip_network
|
from ipaddress import ip_network
|
||||||
from unittest.mock import patch
|
from unittest.mock import Mock, patch
|
||||||
|
|
||||||
from aiohttp import BasicAuth, web
|
from aiohttp import BasicAuth, web
|
||||||
from aiohttp.web_exceptions import HTTPUnauthorized
|
from aiohttp.web_exceptions import HTTPUnauthorized
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.auth.providers import trusted_networks
|
from homeassistant.auth.providers import trusted_networks
|
||||||
from homeassistant.components.http.auth import async_sign_path, setup_auth
|
from homeassistant.components.http.auth import (
|
||||||
|
async_sign_path,
|
||||||
|
async_user_not_allowed_do_auth,
|
||||||
|
setup_auth,
|
||||||
|
)
|
||||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||||
from homeassistant.components.http.forwarded import async_setup_forwarded
|
from homeassistant.components.http.forwarded import async_setup_forwarded
|
||||||
from homeassistant.setup import async_setup_component
|
from homeassistant.setup import async_setup_component
|
||||||
@ -26,7 +30,8 @@ TRUSTED_NETWORKS = [
|
|||||||
ip_network("FD01:DB8::1"),
|
ip_network("FD01:DB8::1"),
|
||||||
]
|
]
|
||||||
TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"]
|
TRUSTED_ADDRESSES = ["100.64.0.1", "192.0.2.100", "FD01:DB8::1", "2001:DB8:ABCD::1"]
|
||||||
UNTRUSTED_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1", "127.0.0.1", "::1"]
|
EXTERNAL_ADDRESSES = ["198.51.100.1", "2001:DB8:FA1::1"]
|
||||||
|
UNTRUSTED_ADDRESSES = [*EXTERNAL_ADDRESSES, "127.0.0.1", "::1"]
|
||||||
|
|
||||||
|
|
||||||
async def mock_handler(request):
|
async def mock_handler(request):
|
||||||
@ -270,3 +275,68 @@ async def test_auth_access_signed_path(hass, app, aiohttp_client, hass_access_to
|
|||||||
await hass.auth.async_remove_refresh_token(refresh_token)
|
await hass.auth.async_remove_refresh_token(refresh_token)
|
||||||
req = await client.get(signed_path)
|
req = await client.get(signed_path)
|
||||||
assert req.status == HTTPStatus.UNAUTHORIZED
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_local_only_user_rejected(hass, app, aiohttp_client, hass_access_token):
|
||||||
|
"""Test access with access token in header."""
|
||||||
|
token = hass_access_token
|
||||||
|
setup_auth(hass, app)
|
||||||
|
set_mock_ip = mock_real_ip(app)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
refresh_token = await hass.auth.async_validate_access_token(hass_access_token)
|
||||||
|
|
||||||
|
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert req.status == HTTPStatus.OK
|
||||||
|
assert await req.json() == {"user_id": refresh_token.user.id}
|
||||||
|
|
||||||
|
refresh_token.user.local_only = True
|
||||||
|
|
||||||
|
for remote_addr in EXTERNAL_ADDRESSES:
|
||||||
|
set_mock_ip(remote_addr)
|
||||||
|
req = await client.get("/", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert req.status == HTTPStatus.UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_async_user_not_allowed_do_auth(hass, app):
|
||||||
|
"""Test for not allowing auth."""
|
||||||
|
user = await hass.auth.async_create_user("Hello")
|
||||||
|
user.is_active = False
|
||||||
|
|
||||||
|
# User not active
|
||||||
|
assert async_user_not_allowed_do_auth(hass, user) == "User is not active"
|
||||||
|
|
||||||
|
user.is_active = True
|
||||||
|
user.local_only = True
|
||||||
|
|
||||||
|
# No current request
|
||||||
|
assert (
|
||||||
|
async_user_not_allowed_do_auth(hass, user)
|
||||||
|
== "No request available to validate local access"
|
||||||
|
)
|
||||||
|
|
||||||
|
trusted_request = Mock(remote="192.168.1.123")
|
||||||
|
untrusted_request = Mock(remote=UNTRUSTED_ADDRESSES[0])
|
||||||
|
|
||||||
|
# Is Remote IP and local only (cloud not loaded)
|
||||||
|
assert async_user_not_allowed_do_auth(hass, user, trusted_request) is None
|
||||||
|
assert (
|
||||||
|
async_user_not_allowed_do_auth(hass, user, untrusted_request)
|
||||||
|
== "User cannot authenticate remotely"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mimic cloud loaded and validate local IP again
|
||||||
|
hass.config.components.add("cloud")
|
||||||
|
assert async_user_not_allowed_do_auth(hass, user, trusted_request) is None
|
||||||
|
assert (
|
||||||
|
async_user_not_allowed_do_auth(hass, user, untrusted_request)
|
||||||
|
== "User cannot authenticate remotely"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Is Cloud request and local only, even a local IP will fail
|
||||||
|
with patch(
|
||||||
|
"hass_nabucasa.remote.is_cloud_request", Mock(get=Mock(return_value=True))
|
||||||
|
):
|
||||||
|
assert (
|
||||||
|
async_user_not_allowed_do_auth(hass, user, trusted_request)
|
||||||
|
== "User is local only"
|
||||||
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user