mirror of
https://github.com/home-assistant/core.git
synced 2025-07-24 21:57:51 +00:00
Enforce permissions for Websocket API (#18719)
* Handle unauth exceptions in websocket * Enforce permissions in websocket API
This commit is contained in:
parent
7248c9cb0e
commit
9d7b1fc3a7
@ -3,6 +3,7 @@ import voluptuous as vol
|
|||||||
|
|
||||||
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
|
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
|
||||||
from homeassistant.core import callback, DOMAIN as HASS_DOMAIN
|
from homeassistant.core import callback, DOMAIN as HASS_DOMAIN
|
||||||
|
from homeassistant.exceptions import Unauthorized
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
from homeassistant.helpers.service import async_get_all_descriptions
|
from homeassistant.helpers.service import async_get_all_descriptions
|
||||||
|
|
||||||
@ -98,6 +99,9 @@ def handle_subscribe_events(hass, connection, msg):
|
|||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
|
if not connection.user.is_admin:
|
||||||
|
raise Unauthorized
|
||||||
|
|
||||||
async def forward_events(event):
|
async def forward_events(event):
|
||||||
"""Forward events to websocket."""
|
"""Forward events to websocket."""
|
||||||
if event.event_type == EVENT_TIME_CHANGED:
|
if event.event_type == EVENT_TIME_CHANGED:
|
||||||
@ -149,8 +153,14 @@ def handle_get_states(hass, connection, msg):
|
|||||||
|
|
||||||
Async friendly.
|
Async friendly.
|
||||||
"""
|
"""
|
||||||
|
entity_perm = connection.user.permissions.check_entity
|
||||||
|
states = [
|
||||||
|
state for state in hass.states.async_all()
|
||||||
|
if entity_perm(state.entity_id, 'read')
|
||||||
|
]
|
||||||
|
|
||||||
connection.send_message(messages.result_message(
|
connection.send_message(messages.result_message(
|
||||||
msg['id'], hass.states.async_all()))
|
msg['id'], states))
|
||||||
|
|
||||||
|
|
||||||
@decorators.async_response
|
@decorators.async_response
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.core import callback, Context
|
from homeassistant.core import callback, Context
|
||||||
|
from homeassistant.exceptions import Unauthorized
|
||||||
|
|
||||||
from . import const, messages
|
from . import const, messages
|
||||||
|
|
||||||
@ -63,11 +64,8 @@ class ActiveConnection:
|
|||||||
|
|
||||||
try:
|
try:
|
||||||
handler(self.hass, self, schema(msg))
|
handler(self.hass, self, schema(msg))
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
self.logger.exception('Error handling message: %s', msg)
|
self.async_handle_exception(msg, err)
|
||||||
self.send_message(messages.error_message(
|
|
||||||
cur_id, const.ERR_UNKNOWN_ERROR,
|
|
||||||
'Unknown error.'))
|
|
||||||
|
|
||||||
self.last_id = cur_id
|
self.last_id = cur_id
|
||||||
|
|
||||||
@ -76,3 +74,20 @@ class ActiveConnection:
|
|||||||
"""Close down connection."""
|
"""Close down connection."""
|
||||||
for unsub in self.event_listeners.values():
|
for unsub in self.event_listeners.values():
|
||||||
unsub()
|
unsub()
|
||||||
|
|
||||||
|
@callback
|
||||||
|
def async_handle_exception(self, msg, err):
|
||||||
|
"""Handle an exception while processing a handler."""
|
||||||
|
if isinstance(err, Unauthorized):
|
||||||
|
code = const.ERR_UNAUTHORIZED
|
||||||
|
err_message = 'Unauthorized'
|
||||||
|
elif isinstance(err, vol.Invalid):
|
||||||
|
code = const.ERR_INVALID_FORMAT
|
||||||
|
err_message = 'Invalid format'
|
||||||
|
else:
|
||||||
|
self.logger.exception('Error handling message: %s', msg)
|
||||||
|
code = const.ERR_UNKNOWN_ERROR
|
||||||
|
err_message = 'Unknown error'
|
||||||
|
|
||||||
|
self.send_message(
|
||||||
|
messages.error_message(msg['id'], code, err_message))
|
||||||
|
@ -6,11 +6,12 @@ DOMAIN = 'websocket_api'
|
|||||||
URL = '/api/websocket'
|
URL = '/api/websocket'
|
||||||
MAX_PENDING_MSG = 512
|
MAX_PENDING_MSG = 512
|
||||||
|
|
||||||
ERR_ID_REUSE = 1
|
ERR_ID_REUSE = 'id_reuse'
|
||||||
ERR_INVALID_FORMAT = 2
|
ERR_INVALID_FORMAT = 'invalid_format'
|
||||||
ERR_NOT_FOUND = 3
|
ERR_NOT_FOUND = 'not_found'
|
||||||
ERR_UNKNOWN_COMMAND = 4
|
ERR_UNKNOWN_COMMAND = 'unknown_command'
|
||||||
ERR_UNKNOWN_ERROR = 5
|
ERR_UNKNOWN_ERROR = 'unknown_error'
|
||||||
|
ERR_UNAUTHORIZED = 'unauthorized'
|
||||||
|
|
||||||
TYPE_RESULT = 'result'
|
TYPE_RESULT = 'result'
|
||||||
|
|
||||||
|
@ -14,10 +14,8 @@ async def _handle_async_response(func, hass, connection, msg):
|
|||||||
"""Create a response and handle exception."""
|
"""Create a response and handle exception."""
|
||||||
try:
|
try:
|
||||||
await func(hass, connection, msg)
|
await func(hass, connection, msg)
|
||||||
except Exception: # pylint: disable=broad-except
|
except Exception as err: # pylint: disable=broad-except
|
||||||
_LOGGER.exception("Unexpected exception")
|
connection.async_handle_exception(msg, err)
|
||||||
connection.send_message(messages.error_message(
|
|
||||||
msg['id'], 'unknown', 'Unexpected error occurred'))
|
|
||||||
|
|
||||||
|
|
||||||
def async_response(func):
|
def async_response(func):
|
||||||
|
@ -9,9 +9,10 @@ from . import API_PASSWORD
|
|||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def websocket_client(hass, hass_ws_client):
|
def websocket_client(hass, hass_ws_client, hass_access_token):
|
||||||
"""Create a websocket client."""
|
"""Create a websocket client."""
|
||||||
return hass.loop.run_until_complete(hass_ws_client(hass))
|
return hass.loop.run_until_complete(
|
||||||
|
hass_ws_client(hass, hass_access_token))
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -261,3 +261,42 @@ async def test_call_service_context_no_user(hass, aiohttp_client):
|
|||||||
assert call.service == 'test_service'
|
assert call.service == 'test_service'
|
||||||
assert call.data == {'hello': 'world'}
|
assert call.data == {'hello': 'world'}
|
||||||
assert call.context.user_id is None
|
assert call.context.user_id is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_subscribe_requires_admin(websocket_client, hass_admin_user):
|
||||||
|
"""Test subscribing events without being admin."""
|
||||||
|
hass_admin_user.groups = []
|
||||||
|
await websocket_client.send_json({
|
||||||
|
'id': 5,
|
||||||
|
'type': commands.TYPE_SUBSCRIBE_EVENTS,
|
||||||
|
'event_type': 'test_event'
|
||||||
|
})
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert not msg['success']
|
||||||
|
assert msg['error']['code'] == const.ERR_UNAUTHORIZED
|
||||||
|
|
||||||
|
|
||||||
|
async def test_states_filters_visible(hass, hass_admin_user, websocket_client):
|
||||||
|
"""Test we only get entities that we're allowed to see."""
|
||||||
|
hass_admin_user.mock_policy({
|
||||||
|
'entities': {
|
||||||
|
'entity_ids': {
|
||||||
|
'test.entity': True
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
hass.states.async_set('test.entity', 'hello')
|
||||||
|
hass.states.async_set('test.not_visible_entity', 'invisible')
|
||||||
|
await websocket_client.send_json({
|
||||||
|
'id': 5,
|
||||||
|
'type': commands.TYPE_GET_STATES,
|
||||||
|
})
|
||||||
|
|
||||||
|
msg = await websocket_client.receive_json()
|
||||||
|
assert msg['id'] == 5
|
||||||
|
assert msg['type'] == const.TYPE_RESULT
|
||||||
|
assert msg['success']
|
||||||
|
|
||||||
|
assert len(msg['result']) == 1
|
||||||
|
assert msg['result'][0]['entity_id'] == 'test.entity'
|
||||||
|
Loading…
x
Reference in New Issue
Block a user