Fix hassio auth data (#39244)

Co-authored-by: Pascal Vizeli <pvizeli@syshack.ch>
This commit is contained in:
Paulus Schoutsen 2020-08-25 14:22:50 +02:00 committed by GitHub
parent 13df3bce1b
commit 9979e465aa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 66 additions and 95 deletions

View File

@ -30,7 +30,8 @@ def _disallow_id(conf: Dict[str, Any]) -> Dict[str, Any]:
CONFIG_SCHEMA = vol.All(AUTH_PROVIDER_SCHEMA, _disallow_id) CONFIG_SCHEMA = vol.All(AUTH_PROVIDER_SCHEMA, _disallow_id)
async def async_get_provider(hass: HomeAssistant) -> "HassAuthProvider": @callback
def async_get_provider(hass: HomeAssistant) -> "HassAuthProvider":
"""Get the provider.""" """Get the provider."""
for prv in hass.auth.auth_providers: for prv in hass.auth.auth_providers:
if prv.type == "homeassistant": if prv.type == "homeassistant":

View File

@ -30,7 +30,7 @@ async def async_setup(hass):
@websocket_api.async_response @websocket_api.async_response
async def websocket_create(hass, connection, msg): async def websocket_create(hass, connection, msg):
"""Create credentials and attach to a user.""" """Create credentials and attach to a user."""
provider = await auth_ha.async_get_provider(hass) provider = auth_ha.async_get_provider(hass)
user = await hass.auth.async_get_user(msg["user_id"]) user = await hass.auth.async_get_user(msg["user_id"])
if user is None: if user is None:
@ -77,7 +77,7 @@ async def websocket_create(hass, connection, msg):
@websocket_api.async_response @websocket_api.async_response
async def websocket_delete(hass, connection, msg): async def websocket_delete(hass, connection, msg):
"""Delete username and related credential.""" """Delete username and related credential."""
provider = await auth_ha.async_get_provider(hass) provider = auth_ha.async_get_provider(hass)
credentials = await provider.async_get_or_create_credentials( credentials = await provider.async_get_or_create_credentials(
{"username": msg["username"]} {"username": msg["username"]}
) )
@ -120,7 +120,7 @@ async def websocket_change_password(hass, connection, msg):
) )
return return
provider = await auth_ha.async_get_provider(hass) provider = auth_ha.async_get_provider(hass)
username = None username = None
for credential in user.credentials: for credential in user.credentials:
if credential.auth_provider_type == provider.type: if credential.auth_provider_type == provider.type:
@ -166,7 +166,7 @@ async def websocket_admin_change_password(hass, connection, msg):
if not connection.user.is_owner: if not connection.user.is_owner:
raise Unauthorized(context=connection.context(msg)) raise Unauthorized(context=connection.context(msg))
provider = await auth_ha.async_get_provider(hass) provider = auth_ha.async_get_provider(hass)
try: try:
await provider.async_change_password(msg["username"], msg["password"]) await provider.async_change_password(msg["username"], msg["password"])
connection.send_message(websocket_api.result_message(msg["id"])) connection.send_message(websocket_api.result_message(msg["id"]))

View File

@ -4,20 +4,16 @@ import logging
import os import os
from aiohttp import web from aiohttp import web
from aiohttp.web_exceptions import ( from aiohttp.web_exceptions import HTTPNotFound, HTTPUnauthorized
HTTPInternalServerError,
HTTPNotFound,
HTTPUnauthorized,
)
import voluptuous as vol import voluptuous as vol
from homeassistant.auth.models import User from homeassistant.auth.models import User
from homeassistant.auth.providers import homeassistant as auth_ha
from homeassistant.components.http import HomeAssistantView from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.const import KEY_HASS_USER from homeassistant.components.http.const import KEY_HASS_USER
from homeassistant.components.http.data_validator import RequestDataValidator from homeassistant.components.http.data_validator import RequestDataValidator
from homeassistant.const import HTTP_OK from homeassistant.const import HTTP_OK
from homeassistant.core import callback from homeassistant.core import callback
from homeassistant.exceptions import HomeAssistantError
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.typing import HomeAssistantType from homeassistant.helpers.typing import HomeAssistantType
@ -26,21 +22,6 @@ from .const import ATTR_ADDON, ATTR_PASSWORD, ATTR_USERNAME
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
SCHEMA_API_AUTH = vol.Schema(
{
vol.Required(ATTR_USERNAME): cv.string,
vol.Required(ATTR_PASSWORD): cv.string,
vol.Required(ATTR_ADDON): cv.string,
},
extra=vol.ALLOW_EXTRA,
)
SCHEMA_API_PASSWORD_RESET = vol.Schema(
{vol.Required(ATTR_USERNAME): cv.string, vol.Required(ATTR_PASSWORD): cv.string},
extra=vol.ALLOW_EXTRA,
)
@callback @callback
def async_setup_auth_view(hass: HomeAssistantType, user: User): def async_setup_auth_view(hass: HomeAssistantType, user: User):
"""Auth setup.""" """Auth setup."""
@ -74,15 +55,6 @@ class HassIOBaseAuth(HomeAssistantView):
_LOGGER.error("Invalid auth request from %s", request[KEY_HASS_USER].name) _LOGGER.error("Invalid auth request from %s", request[KEY_HASS_USER].name)
raise HTTPUnauthorized() raise HTTPUnauthorized()
def _get_provider(self):
"""Return Homeassistant auth provider."""
prv = self.hass.auth.get_auth_provider("homeassistant", None)
if prv is not None:
return prv
_LOGGER.error("Can't find Home Assistant auth")
raise HTTPNotFound()
class HassIOAuth(HassIOBaseAuth): class HassIOAuth(HassIOBaseAuth):
"""Hass.io view to handle auth requests.""" """Hass.io view to handle auth requests."""
@ -90,23 +62,30 @@ class HassIOAuth(HassIOBaseAuth):
name = "api:hassio:auth" name = "api:hassio:auth"
url = "/api/hassio_auth" url = "/api/hassio_auth"
@RequestDataValidator(SCHEMA_API_AUTH) @RequestDataValidator(
vol.Schema(
{
vol.Required(ATTR_USERNAME): cv.string,
vol.Required(ATTR_PASSWORD): cv.string,
vol.Required(ATTR_ADDON): cv.string,
},
extra=vol.ALLOW_EXTRA,
)
)
async def post(self, request, data): async def post(self, request, data):
"""Handle auth requests.""" """Handle auth requests."""
self._check_access(request) self._check_access(request)
provider = auth_ha.async_get_provider(request.app["hass"])
await self._check_login(data[ATTR_USERNAME], data[ATTR_PASSWORD])
return web.Response(status=HTTP_OK)
async def _check_login(self, username, password):
"""Check User credentials."""
provider = self._get_provider()
try: try:
await provider.async_validate_login(username, password) await provider.async_validate_login(
except HomeAssistantError: data[ATTR_USERNAME], data[ATTR_PASSWORD]
)
except auth_ha.InvalidAuth:
raise HTTPUnauthorized() from None raise HTTPUnauthorized() from None
return web.Response(status=HTTP_OK)
class HassIOPasswordReset(HassIOBaseAuth): class HassIOPasswordReset(HassIOBaseAuth):
"""Hass.io view to handle password reset requests.""" """Hass.io view to handle password reset requests."""
@ -114,22 +93,25 @@ class HassIOPasswordReset(HassIOBaseAuth):
name = "api:hassio:auth:password:reset" name = "api:hassio:auth:password:reset"
url = "/api/hassio_auth/password_reset" url = "/api/hassio_auth/password_reset"
@RequestDataValidator(SCHEMA_API_PASSWORD_RESET) @RequestDataValidator(
vol.Schema(
{
vol.Required(ATTR_USERNAME): cv.string,
vol.Required(ATTR_PASSWORD): cv.string,
},
extra=vol.ALLOW_EXTRA,
)
)
async def post(self, request, data): async def post(self, request, data):
"""Handle password reset requests.""" """Handle password reset requests."""
self._check_access(request) self._check_access(request)
provider = auth_ha.async_get_provider(request.app["hass"])
await self._change_password(data[ATTR_USERNAME], data[ATTR_PASSWORD])
return web.Response(status=HTTP_OK)
async def _change_password(self, username, password):
"""Check User credentials."""
provider = self._get_provider()
try: try:
await self.hass.async_add_executor_job( await provider.async_change_password(
provider.data.change_password, username, password data[ATTR_USERNAME], data[ATTR_PASSWORD]
) )
await provider.data.async_save() except auth_ha.InvalidUser:
except HomeAssistantError: raise HTTPNotFound()
raise HTTPInternalServerError()
return web.Response(status=HTTP_OK)

View File

@ -1,7 +1,6 @@
"""The tests for the hassio component.""" """The tests for the hassio component."""
from homeassistant.const import HTTP_INTERNAL_SERVER_ERROR from homeassistant.auth.providers.homeassistant import InvalidAuth
from homeassistant.exceptions import HomeAssistantError
from tests.async_mock import Mock, patch from tests.async_mock import Mock, patch
@ -59,7 +58,7 @@ async def test_login_error(hass, hassio_client_supervisor):
with patch( with patch(
"homeassistant.auth.providers.homeassistant." "homeassistant.auth.providers.homeassistant."
"HassAuthProvider.async_validate_login", "HassAuthProvider.async_validate_login",
Mock(side_effect=HomeAssistantError()), Mock(side_effect=InvalidAuth()),
) as mock_login: ) as mock_login:
resp = await hassio_client_supervisor.post( resp = await hassio_client_supervisor.post(
"/api/hassio_auth", "/api/hassio_auth",
@ -76,7 +75,7 @@ async def test_login_no_data(hass, hassio_client_supervisor):
with patch( with patch(
"homeassistant.auth.providers.homeassistant." "homeassistant.auth.providers.homeassistant."
"HassAuthProvider.async_validate_login", "HassAuthProvider.async_validate_login",
Mock(side_effect=HomeAssistantError()), Mock(side_effect=InvalidAuth()),
) as mock_login: ) as mock_login:
resp = await hassio_client_supervisor.post("/api/hassio_auth") resp = await hassio_client_supervisor.post("/api/hassio_auth")
@ -90,7 +89,7 @@ async def test_login_no_username(hass, hassio_client_supervisor):
with patch( with patch(
"homeassistant.auth.providers.homeassistant." "homeassistant.auth.providers.homeassistant."
"HassAuthProvider.async_validate_login", "HassAuthProvider.async_validate_login",
Mock(side_effect=HomeAssistantError()), Mock(side_effect=InvalidAuth()),
) as mock_login: ) as mock_login:
resp = await hassio_client_supervisor.post( resp = await hassio_client_supervisor.post(
"/api/hassio_auth", json={"password": "123456", "addon": "samba"} "/api/hassio_auth", json={"password": "123456", "addon": "samba"}
@ -125,7 +124,8 @@ async def test_login_success_extra(hass, hassio_client_supervisor):
async def test_password_success(hass, hassio_client_supervisor): async def test_password_success(hass, hassio_client_supervisor):
"""Test no auth needed for .""" """Test no auth needed for ."""
with patch( with patch(
"homeassistant.components.hassio.auth.HassIOPasswordReset._change_password", "homeassistant.auth.providers.homeassistant."
"HassAuthProvider.async_change_password",
) as mock_change: ) as mock_change:
resp = await hassio_client_supervisor.post( resp = await hassio_client_supervisor.post(
"/api/hassio_auth/password_reset", "/api/hassio_auth/password_reset",
@ -139,44 +139,32 @@ async def test_password_success(hass, hassio_client_supervisor):
async def test_password_fails_no_supervisor(hass, hassio_client): async def test_password_fails_no_supervisor(hass, hassio_client):
"""Test if only supervisor can access.""" """Test if only supervisor can access."""
with patch( resp = await hassio_client.post(
"homeassistant.auth.providers.homeassistant.Data.async_save", "/api/hassio_auth/password_reset",
) as mock_save: json={"username": "test", "password": "123456"},
resp = await hassio_client.post( )
"/api/hassio_auth/password_reset",
json={"username": "test", "password": "123456"},
)
# Check we got right response # Check we got right response
assert resp.status == 401 assert resp.status == 401
assert not mock_save.called
async def test_password_fails_no_auth(hass, hassio_noauth_client): async def test_password_fails_no_auth(hass, hassio_noauth_client):
"""Test if only supervisor can access.""" """Test if only supervisor can access."""
with patch( resp = await hassio_noauth_client.post(
"homeassistant.auth.providers.homeassistant.Data.async_save", "/api/hassio_auth/password_reset",
) as mock_save: json={"username": "test", "password": "123456"},
resp = await hassio_noauth_client.post( )
"/api/hassio_auth/password_reset",
json={"username": "test", "password": "123456"},
)
# Check we got right response # Check we got right response
assert resp.status == 401 assert resp.status == 401
assert not mock_save.called
async def test_password_no_user(hass, hassio_client_supervisor): async def test_password_no_user(hass, hassio_client_supervisor):
"""Test no auth needed for .""" """Test changing password for invalid user."""
with patch( resp = await hassio_client_supervisor.post(
"homeassistant.auth.providers.homeassistant.Data.async_save", "/api/hassio_auth/password_reset",
) as mock_save: json={"username": "test", "password": "123456"},
resp = await hassio_client_supervisor.post( )
"/api/hassio_auth/password_reset",
json={"username": "test", "password": "123456"},
)
# Check we got right response # Check we got right response
assert resp.status == HTTP_INTERNAL_SERVER_ERROR assert resp.status == 404
assert not mock_save.called