mirror of
https://github.com/home-assistant/core.git
synced 2025-04-25 01:38:02 +00:00
Break up websocket component (#17003)
* Break up websocket component * Lint
This commit is contained in:
parent
9edf1e5151
commit
22a80cf733
@ -452,27 +452,23 @@ class CameraMjpegStream(CameraView):
|
||||
raise web.HTTPBadRequest()
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_camera_thumbnail(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_camera_thumbnail(hass, connection, msg):
|
||||
"""Handle get camera thumbnail websocket command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def send_camera_still():
|
||||
"""Send a camera still."""
|
||||
try:
|
||||
image = await async_get_image(hass, msg['entity_id'])
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': image.content_type,
|
||||
'content': base64.b64encode(image.content).decode('utf-8')
|
||||
}
|
||||
))
|
||||
except HomeAssistantError:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
|
||||
|
||||
hass.async_add_job(send_camera_still())
|
||||
try:
|
||||
image = await async_get_image(hass, msg['entity_id'])
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': image.content_type,
|
||||
'content': base64.b64encode(image.content).decode('utf-8')
|
||||
}
|
||||
))
|
||||
except HomeAssistantError:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
|
||||
|
||||
|
||||
async def async_handle_snapshot_service(camera, service):
|
||||
|
@ -3,6 +3,7 @@ import voluptuous as vol
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.decorators import require_owner
|
||||
|
||||
|
||||
WS_TYPE_LIST = 'config/auth/list'
|
||||
@ -41,7 +42,7 @@ async def async_setup(hass):
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
@require_owner
|
||||
def websocket_list(hass, connection, msg):
|
||||
"""Return a list of users."""
|
||||
async def send_users():
|
||||
@ -55,7 +56,7 @@ def websocket_list(hass, connection, msg):
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
@require_owner
|
||||
def websocket_delete(hass, connection, msg):
|
||||
"""Delete a user."""
|
||||
async def delete_user():
|
||||
@ -82,7 +83,7 @@ def websocket_delete(hass, connection, msg):
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
@require_owner
|
||||
def websocket_create(hass, connection, msg):
|
||||
"""Create a user."""
|
||||
async def create_user():
|
||||
|
@ -4,6 +4,7 @@ import voluptuous as vol
|
||||
from homeassistant.auth.providers import homeassistant as auth_ha
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.decorators import require_owner
|
||||
|
||||
|
||||
WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create'
|
||||
@ -55,7 +56,7 @@ def _get_provider(hass):
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
@require_owner
|
||||
def websocket_create(hass, connection, msg):
|
||||
"""Create credentials and attach to a user."""
|
||||
async def create_creds():
|
||||
@ -96,7 +97,7 @@ def websocket_create(hass, connection, msg):
|
||||
|
||||
|
||||
@callback
|
||||
@websocket_api.require_owner
|
||||
@require_owner
|
||||
def websocket_delete(hass, connection, msg):
|
||||
"""Delete username and related credential."""
|
||||
async def delete_creds():
|
||||
|
@ -4,6 +4,8 @@ import voluptuous as vol
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity_registry import async_get_registry
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
|
||||
from homeassistant.components.websocket_api.decorators import async_response
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
DEPENDENCIES = ['websocket_api']
|
||||
@ -46,89 +48,77 @@ async def async_setup(hass):
|
||||
return True
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_list_entities(hass, connection, msg):
|
||||
@async_response
|
||||
async def websocket_list_entities(hass, connection, msg):
|
||||
"""Handle list registry entries command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def retrieve_entities():
|
||||
"""Get entities from registry."""
|
||||
registry = await async_get_registry(hass)
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], [{
|
||||
'config_entry_id': entry.config_entry_id,
|
||||
'device_id': entry.device_id,
|
||||
'disabled_by': entry.disabled_by,
|
||||
'entity_id': entry.entity_id,
|
||||
'name': entry.name,
|
||||
'platform': entry.platform,
|
||||
} for entry in registry.entities.values()]
|
||||
))
|
||||
|
||||
hass.async_add_job(retrieve_entities())
|
||||
registry = await async_get_registry(hass)
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], [{
|
||||
'config_entry_id': entry.config_entry_id,
|
||||
'device_id': entry.device_id,
|
||||
'disabled_by': entry.disabled_by,
|
||||
'entity_id': entry.entity_id,
|
||||
'name': entry.name,
|
||||
'platform': entry.platform,
|
||||
} for entry in registry.entities.values()]
|
||||
))
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_get_entity(hass, connection, msg):
|
||||
@async_response
|
||||
async def websocket_get_entity(hass, connection, msg):
|
||||
"""Handle get entity registry entry command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def retrieve_entity():
|
||||
"""Get entity from registry."""
|
||||
registry = await async_get_registry(hass)
|
||||
entry = registry.entities.get(msg['entity_id'])
|
||||
registry = await async_get_registry(hass)
|
||||
entry = registry.entities.get(msg['entity_id'])
|
||||
|
||||
if entry is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
if entry is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
hass.async_add_job(retrieve_entity())
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_update_entity(hass, connection, msg):
|
||||
@async_response
|
||||
async def websocket_update_entity(hass, connection, msg):
|
||||
"""Handle get camera thumbnail websocket command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def update_entity():
|
||||
"""Get entity from registry."""
|
||||
registry = await async_get_registry(hass)
|
||||
registry = await async_get_registry(hass)
|
||||
|
||||
if msg['entity_id'] not in registry.entities:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
if msg['entity_id'] not in registry.entities:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
|
||||
return
|
||||
|
||||
changes = {}
|
||||
changes = {}
|
||||
|
||||
if 'name' in msg:
|
||||
changes['name'] = msg['name']
|
||||
if 'name' in msg:
|
||||
changes['name'] = msg['name']
|
||||
|
||||
if 'new_entity_id' in msg:
|
||||
changes['new_entity_id'] = msg['new_entity_id']
|
||||
if 'new_entity_id' in msg:
|
||||
changes['new_entity_id'] = msg['new_entity_id']
|
||||
|
||||
try:
|
||||
if changes:
|
||||
entry = registry.async_update_entity(
|
||||
msg['entity_id'], **changes)
|
||||
except ValueError as err:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'invalid_info', str(err)
|
||||
))
|
||||
else:
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
hass.async_create_task(update_entity())
|
||||
try:
|
||||
if changes:
|
||||
entry = registry.async_update_entity(
|
||||
msg['entity_id'], **changes)
|
||||
except ValueError as err:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'invalid_info', str(err)
|
||||
))
|
||||
else:
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], _entry_dict(entry)
|
||||
))
|
||||
|
||||
|
||||
@callback
|
||||
|
@ -28,7 +28,6 @@ from homeassistant.const import (
|
||||
SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_VOLUME_DOWN,
|
||||
SERVICE_VOLUME_MUTE, SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, STATE_IDLE,
|
||||
STATE_OFF, STATE_PLAYING, STATE_UNKNOWN)
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
import homeassistant.helpers.config_validation as cv
|
||||
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
|
||||
@ -865,8 +864,8 @@ class MediaPlayerImageView(HomeAssistantView):
|
||||
body=data, content_type=content_type, headers=headers)
|
||||
|
||||
|
||||
@callback
|
||||
def websocket_handle_thumbnail(hass, connection, msg):
|
||||
@websocket_api.async_response
|
||||
async def websocket_handle_thumbnail(hass, connection, msg):
|
||||
"""Handle get media player cover command.
|
||||
|
||||
Async friendly.
|
||||
@ -879,20 +878,16 @@ def websocket_handle_thumbnail(hass, connection, msg):
|
||||
msg['id'], 'entity_not_found', 'Entity not found'))
|
||||
return
|
||||
|
||||
async def send_image():
|
||||
"""Send image."""
|
||||
data, content_type = await player.async_get_media_image()
|
||||
data, content_type = await player.async_get_media_image()
|
||||
|
||||
if data is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'thumbnail_fetch_failed',
|
||||
'Failed to fetch thumbnail'))
|
||||
return
|
||||
if data is None:
|
||||
connection.send_message_outside(websocket_api.error_message(
|
||||
msg['id'], 'thumbnail_fetch_failed',
|
||||
'Failed to fetch thumbnail'))
|
||||
return
|
||||
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': content_type,
|
||||
'content': base64.b64encode(data).decode('utf-8')
|
||||
}))
|
||||
|
||||
hass.async_add_job(send_image())
|
||||
connection.send_message_outside(websocket_api.result_message(
|
||||
msg['id'], {
|
||||
'content_type': content_type,
|
||||
'content': base64.b64encode(data).decode('utf-8')
|
||||
}))
|
||||
|
@ -7,7 +7,7 @@ https://developers.home-assistant.io/docs/external_api_websocket.html
|
||||
import asyncio
|
||||
from concurrent import futures
|
||||
from contextlib import suppress
|
||||
from functools import partial, wraps
|
||||
from functools import partial
|
||||
import json
|
||||
import logging
|
||||
|
||||
@ -15,20 +15,18 @@ from aiohttp import web
|
||||
import voluptuous as vol
|
||||
from voluptuous.humanize import humanize_error
|
||||
|
||||
from homeassistant.const import (
|
||||
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
|
||||
__version__)
|
||||
from homeassistant.core import Context, callback, HomeAssistant
|
||||
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
|
||||
from homeassistant.core import Context, callback
|
||||
from homeassistant.loader import bind_hass
|
||||
from homeassistant.helpers.json import JSONEncoder
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
from homeassistant.components.http import HomeAssistantView
|
||||
from homeassistant.components.http.auth import validate_password
|
||||
from homeassistant.components.http.const import KEY_AUTHENTICATED
|
||||
from homeassistant.components.http.ban import process_wrong_login, \
|
||||
process_success_login
|
||||
|
||||
from . import commands, const, decorators, messages
|
||||
|
||||
DOMAIN = 'websocket_api'
|
||||
|
||||
URL = '/api/websocket'
|
||||
@ -36,87 +34,32 @@ DEPENDENCIES = ('http',)
|
||||
|
||||
MAX_PENDING_MSG = 512
|
||||
|
||||
ERR_ID_REUSE = 1
|
||||
ERR_INVALID_FORMAT = 2
|
||||
ERR_NOT_FOUND = 3
|
||||
ERR_UNKNOWN_COMMAND = 4
|
||||
ERR_UNKNOWN_ERROR = 5
|
||||
|
||||
TYPE_AUTH = 'auth'
|
||||
TYPE_AUTH_INVALID = 'auth_invalid'
|
||||
TYPE_AUTH_OK = 'auth_ok'
|
||||
TYPE_AUTH_REQUIRED = 'auth_required'
|
||||
TYPE_CALL_SERVICE = 'call_service'
|
||||
TYPE_EVENT = 'event'
|
||||
TYPE_GET_CONFIG = 'get_config'
|
||||
TYPE_GET_SERVICES = 'get_services'
|
||||
TYPE_GET_STATES = 'get_states'
|
||||
TYPE_PING = 'ping'
|
||||
TYPE_PONG = 'pong'
|
||||
TYPE_RESULT = 'result'
|
||||
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
|
||||
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
|
||||
|
||||
TYPE_AUTH = 'auth'
|
||||
TYPE_AUTH_INVALID = 'auth_invalid'
|
||||
TYPE_AUTH_OK = 'auth_ok'
|
||||
TYPE_AUTH_REQUIRED = 'auth_required'
|
||||
|
||||
|
||||
# Backwards compat
|
||||
# pylint: disable=invalid-name
|
||||
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
|
||||
error_message = messages.error_message
|
||||
result_message = messages.result_message
|
||||
async_response = decorators.async_response
|
||||
ws_require_user = decorators.ws_require_user
|
||||
# pylint: enable=invalid-name
|
||||
|
||||
AUTH_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('type'): TYPE_AUTH,
|
||||
vol.Exclusive('api_password', 'auth'): str,
|
||||
vol.Exclusive('access_token', 'auth'): str,
|
||||
})
|
||||
|
||||
# Minimal requirements of a message
|
||||
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('id'): cv.positive_int,
|
||||
vol.Required('type'): cv.string,
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
# Base schema to extend by message handlers
|
||||
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('id'): cv.positive_int,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
|
||||
vol.Optional('event_type', default=MATCH_ALL): str,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
|
||||
vol.Required('subscription'): cv.positive_int,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_CALL_SERVICE,
|
||||
vol.Required('domain'): str,
|
||||
vol.Required('service'): str,
|
||||
vol.Optional('service_data'): dict
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_STATES,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_SERVICES,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_CONFIG,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_PING,
|
||||
})
|
||||
|
||||
|
||||
# Define the possible errors that occur when connections are cancelled.
|
||||
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
|
||||
@ -148,46 +91,6 @@ def auth_invalid_message(message):
|
||||
}
|
||||
|
||||
|
||||
def event_message(iden, event):
|
||||
"""Return an event message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_EVENT,
|
||||
'event': event.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
def error_message(iden, code, message):
|
||||
"""Return an error result message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_RESULT,
|
||||
'success': False,
|
||||
'error': {
|
||||
'code': code,
|
||||
'message': message,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def pong_message(iden):
|
||||
"""Return a pong message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_PONG,
|
||||
}
|
||||
|
||||
|
||||
def result_message(iden, result=None):
|
||||
"""Return a success result message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_RESULT,
|
||||
'success': True,
|
||||
'result': result,
|
||||
}
|
||||
|
||||
|
||||
@bind_hass
|
||||
@callback
|
||||
def async_register_command(hass, command, handler, schema):
|
||||
@ -198,43 +101,10 @@ def async_register_command(hass, command, handler, schema):
|
||||
handlers[command] = (handler, schema)
|
||||
|
||||
|
||||
def require_owner(func):
|
||||
"""Websocket decorator to require user to be an owner."""
|
||||
@wraps(func)
|
||||
def with_owner(hass, connection, msg):
|
||||
"""Check owner and call function."""
|
||||
user = connection.request.get('hass_user')
|
||||
|
||||
if user is None or not user.is_owner:
|
||||
connection.to_write.put_nowait(error_message(
|
||||
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
||||
return
|
||||
|
||||
func(hass, connection, msg)
|
||||
|
||||
return with_owner
|
||||
|
||||
|
||||
async def async_setup(hass, config):
|
||||
"""Initialize the websocket API."""
|
||||
hass.http.register_view(WebsocketAPIView)
|
||||
|
||||
async_register_command(hass, TYPE_SUBSCRIBE_EVENTS,
|
||||
handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS)
|
||||
async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS,
|
||||
handle_unsubscribe_events,
|
||||
SCHEMA_UNSUBSCRIBE_EVENTS)
|
||||
async_register_command(hass, TYPE_CALL_SERVICE,
|
||||
handle_call_service, SCHEMA_CALL_SERVICE)
|
||||
async_register_command(hass, TYPE_GET_STATES,
|
||||
handle_get_states, SCHEMA_GET_STATES)
|
||||
async_register_command(hass, TYPE_GET_SERVICES,
|
||||
handle_get_services, SCHEMA_GET_SERVICES)
|
||||
async_register_command(hass, TYPE_GET_CONFIG,
|
||||
handle_get_config, SCHEMA_GET_CONFIG)
|
||||
async_register_command(hass, TYPE_PING,
|
||||
handle_ping, SCHEMA_PING)
|
||||
|
||||
commands.async_register_commands(hass)
|
||||
return True
|
||||
|
||||
|
||||
@ -389,19 +259,19 @@ class ActiveConnection:
|
||||
|
||||
while msg:
|
||||
self.debug("Received", msg)
|
||||
msg = MINIMAL_MESSAGE_SCHEMA(msg)
|
||||
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
|
||||
cur_id = msg['id']
|
||||
|
||||
if cur_id <= last_id:
|
||||
self.to_write.put_nowait(error_message(
|
||||
cur_id, ERR_ID_REUSE,
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_ID_REUSE,
|
||||
'Identifier values have to increase.'))
|
||||
|
||||
elif msg['type'] not in handlers:
|
||||
self.log_error(
|
||||
'Received invalid command: {}'.format(msg['type']))
|
||||
self.to_write.put_nowait(error_message(
|
||||
cur_id, ERR_UNKNOWN_COMMAND,
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_COMMAND,
|
||||
'Unknown command.'))
|
||||
|
||||
else:
|
||||
@ -410,8 +280,8 @@ class ActiveConnection:
|
||||
handler(self.hass, self, schema(msg))
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception('Error handling message: %s', msg)
|
||||
self.to_write.put_nowait(error_message(
|
||||
cur_id, ERR_UNKNOWN_ERROR,
|
||||
self.to_write.put_nowait(messages.error_message(
|
||||
cur_id, const.ERR_UNKNOWN_ERROR,
|
||||
'Unknown error.'))
|
||||
|
||||
last_id = cur_id
|
||||
@ -435,8 +305,8 @@ class ActiveConnection:
|
||||
else:
|
||||
iden = None
|
||||
|
||||
final_message = error_message(
|
||||
iden, ERR_INVALID_FORMAT, error_msg)
|
||||
final_message = messages.error_message(
|
||||
iden, const.ERR_INVALID_FORMAT, error_msg)
|
||||
|
||||
except TypeError as err:
|
||||
if wsock.closed:
|
||||
@ -485,170 +355,3 @@ class ActiveConnection:
|
||||
self.debug("Closed connection")
|
||||
|
||||
return wsock
|
||||
|
||||
|
||||
def async_response(func):
|
||||
"""Decorate an async function to handle WebSocket API messages."""
|
||||
async def handle_msg_response(hass, connection, msg):
|
||||
"""Create a response and handle exception."""
|
||||
try:
|
||||
await func(hass, connection, msg)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
connection.send_message_outside(error_message(
|
||||
msg['id'], 'unknown', 'Unexpected error occurred'))
|
||||
|
||||
@callback
|
||||
@wraps(func)
|
||||
def schedule_handler(hass, connection, msg):
|
||||
"""Schedule the handler."""
|
||||
hass.async_create_task(handle_msg_response(hass, connection, msg))
|
||||
|
||||
return schedule_handler
|
||||
|
||||
|
||||
@callback
|
||||
def handle_subscribe_events(hass, connection, msg):
|
||||
"""Handle subscribe events command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def forward_events(event):
|
||||
"""Forward events to websocket."""
|
||||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
connection.send_message_outside(event_message(msg['id'], event))
|
||||
|
||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
||||
msg['event_type'], forward_events)
|
||||
|
||||
connection.to_write.put_nowait(result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_unsubscribe_events(hass, connection, msg):
|
||||
"""Handle unsubscribe events command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
subscription = msg['subscription']
|
||||
|
||||
if subscription in connection.event_listeners:
|
||||
connection.event_listeners.pop(subscription)()
|
||||
connection.to_write.put_nowait(result_message(msg['id']))
|
||||
else:
|
||||
connection.to_write.put_nowait(error_message(
|
||||
msg['id'], ERR_NOT_FOUND, 'Subscription not found.'))
|
||||
|
||||
|
||||
@async_response
|
||||
async def handle_call_service(hass, connection, msg):
|
||||
"""Handle call service command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
blocking = True
|
||||
if (msg['domain'] == 'homeassistant' and
|
||||
msg['service'] in ['restart', 'stop']):
|
||||
blocking = False
|
||||
await hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
||||
connection.context(msg))
|
||||
connection.send_message_outside(result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_get_states(hass, connection, msg):
|
||||
"""Handle get states command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(result_message(
|
||||
msg['id'], hass.states.async_all()))
|
||||
|
||||
|
||||
@async_response
|
||||
async def handle_get_services(hass, connection, msg):
|
||||
"""Handle get services command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
connection.send_message_outside(
|
||||
result_message(msg['id'], descriptions))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_get_config(hass, connection, msg):
|
||||
"""Handle get config command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(result_message(
|
||||
msg['id'], hass.config.as_dict()))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_ping(hass, connection, msg):
|
||||
"""Handle ping command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(pong_message(msg['id']))
|
||||
|
||||
|
||||
def ws_require_user(
|
||||
only_owner=False, only_system_user=False, allow_system_user=True,
|
||||
only_active_user=True, only_inactive_user=False):
|
||||
"""Decorate function validating login user exist in current WS connection.
|
||||
|
||||
Will write out error message if not authenticated.
|
||||
"""
|
||||
def validator(func):
|
||||
"""Decorate func."""
|
||||
@wraps(func)
|
||||
def check_current_user(hass: HomeAssistant,
|
||||
connection: ActiveConnection,
|
||||
msg):
|
||||
"""Check current user."""
|
||||
def output_error(message_id, message):
|
||||
"""Output error message."""
|
||||
connection.send_message_outside(error_message(
|
||||
msg['id'], message_id, message))
|
||||
|
||||
if connection.user is None:
|
||||
output_error('no_user', 'Not authenticated as a user')
|
||||
return
|
||||
|
||||
if only_owner and not connection.user.is_owner:
|
||||
output_error('only_owner', 'Only allowed as owner')
|
||||
return
|
||||
|
||||
if (only_system_user and
|
||||
not connection.user.system_generated):
|
||||
output_error('only_system_user',
|
||||
'Only allowed as system user')
|
||||
return
|
||||
|
||||
if (not allow_system_user
|
||||
and connection.user.system_generated):
|
||||
output_error('not_system_user', 'Not allowed as system user')
|
||||
return
|
||||
|
||||
if (only_active_user and
|
||||
not connection.user.is_active):
|
||||
output_error('only_active_user',
|
||||
'Only allowed as active user')
|
||||
return
|
||||
|
||||
if only_inactive_user and connection.user.is_active:
|
||||
output_error('only_inactive_user',
|
||||
'Not allowed as active user')
|
||||
return
|
||||
|
||||
return func(hass, connection, msg)
|
||||
|
||||
return check_current_user
|
||||
|
||||
return validator
|
183
homeassistant/components/websocket_api/commands.py
Normal file
183
homeassistant/components/websocket_api/commands.py
Normal file
@ -0,0 +1,183 @@
|
||||
"""Commands part of Websocket API."""
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.service import async_get_all_descriptions
|
||||
|
||||
from . import const, decorators, messages
|
||||
|
||||
|
||||
TYPE_CALL_SERVICE = 'call_service'
|
||||
TYPE_EVENT = 'event'
|
||||
TYPE_GET_CONFIG = 'get_config'
|
||||
TYPE_GET_SERVICES = 'get_services'
|
||||
TYPE_GET_STATES = 'get_states'
|
||||
TYPE_PING = 'ping'
|
||||
TYPE_PONG = 'pong'
|
||||
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
|
||||
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
|
||||
|
||||
|
||||
@callback
|
||||
def async_register_commands(hass):
|
||||
"""Register commands."""
|
||||
async_reg = hass.components.websocket_api.async_register_command
|
||||
async_reg(TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events,
|
||||
SCHEMA_SUBSCRIBE_EVENTS)
|
||||
async_reg(TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events,
|
||||
SCHEMA_UNSUBSCRIBE_EVENTS)
|
||||
async_reg(TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE)
|
||||
async_reg(TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES)
|
||||
async_reg(TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES)
|
||||
async_reg(TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG)
|
||||
async_reg(TYPE_PING, handle_ping, SCHEMA_PING)
|
||||
|
||||
|
||||
SCHEMA_SUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
|
||||
vol.Optional('event_type', default=MATCH_ALL): str,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_UNSUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
|
||||
vol.Required('subscription'): cv.positive_int,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_CALL_SERVICE = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_CALL_SERVICE,
|
||||
vol.Required('domain'): str,
|
||||
vol.Required('service'): str,
|
||||
vol.Optional('service_data'): dict
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_STATES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_STATES,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_SERVICES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_SERVICES,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_GET_CONFIG = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_GET_CONFIG,
|
||||
})
|
||||
|
||||
|
||||
SCHEMA_PING = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
|
||||
vol.Required('type'): TYPE_PING,
|
||||
})
|
||||
|
||||
|
||||
def event_message(iden, event):
|
||||
"""Return an event message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_EVENT,
|
||||
'event': event.as_dict(),
|
||||
}
|
||||
|
||||
|
||||
def pong_message(iden):
|
||||
"""Return a pong message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': TYPE_PONG,
|
||||
}
|
||||
|
||||
|
||||
@callback
|
||||
def handle_subscribe_events(hass, connection, msg):
|
||||
"""Handle subscribe events command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
async def forward_events(event):
|
||||
"""Forward events to websocket."""
|
||||
if event.event_type == EVENT_TIME_CHANGED:
|
||||
return
|
||||
|
||||
connection.send_message_outside(event_message(msg['id'], event))
|
||||
|
||||
connection.event_listeners[msg['id']] = hass.bus.async_listen(
|
||||
msg['event_type'], forward_events)
|
||||
|
||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_unsubscribe_events(hass, connection, msg):
|
||||
"""Handle unsubscribe events command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
subscription = msg['subscription']
|
||||
|
||||
if subscription in connection.event_listeners:
|
||||
connection.event_listeners.pop(subscription)()
|
||||
connection.to_write.put_nowait(messages.result_message(msg['id']))
|
||||
else:
|
||||
connection.to_write.put_nowait(messages.error_message(
|
||||
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
|
||||
|
||||
|
||||
@decorators.async_response
|
||||
async def handle_call_service(hass, connection, msg):
|
||||
"""Handle call service command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
blocking = True
|
||||
if (msg['domain'] == 'homeassistant' and
|
||||
msg['service'] in ['restart', 'stop']):
|
||||
blocking = False
|
||||
await hass.services.async_call(
|
||||
msg['domain'], msg['service'], msg.get('service_data'), blocking,
|
||||
connection.context(msg))
|
||||
connection.send_message_outside(messages.result_message(msg['id']))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_get_states(hass, connection, msg):
|
||||
"""Handle get states command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(messages.result_message(
|
||||
msg['id'], hass.states.async_all()))
|
||||
|
||||
|
||||
@decorators.async_response
|
||||
async def handle_get_services(hass, connection, msg):
|
||||
"""Handle get services command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
descriptions = await async_get_all_descriptions(hass)
|
||||
connection.send_message_outside(
|
||||
messages.result_message(msg['id'], descriptions))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_get_config(hass, connection, msg):
|
||||
"""Handle get config command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(messages.result_message(
|
||||
msg['id'], hass.config.as_dict()))
|
||||
|
||||
|
||||
@callback
|
||||
def handle_ping(hass, connection, msg):
|
||||
"""Handle ping command.
|
||||
|
||||
Async friendly.
|
||||
"""
|
||||
connection.to_write.put_nowait(pong_message(msg['id']))
|
8
homeassistant/components/websocket_api/const.py
Normal file
8
homeassistant/components/websocket_api/const.py
Normal file
@ -0,0 +1,8 @@
|
||||
"""Websocket constants."""
|
||||
ERR_ID_REUSE = 1
|
||||
ERR_INVALID_FORMAT = 2
|
||||
ERR_NOT_FOUND = 3
|
||||
ERR_UNKNOWN_COMMAND = 4
|
||||
ERR_UNKNOWN_ERROR = 5
|
||||
|
||||
TYPE_RESULT = 'result'
|
101
homeassistant/components/websocket_api/decorators.py
Normal file
101
homeassistant/components/websocket_api/decorators.py
Normal file
@ -0,0 +1,101 @@
|
||||
"""Decorators for the Websocket API."""
|
||||
from functools import wraps
|
||||
import logging
|
||||
|
||||
from homeassistant.core import callback
|
||||
|
||||
from . import messages
|
||||
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def async_response(func):
|
||||
"""Decorate an async function to handle WebSocket API messages."""
|
||||
async def handle_msg_response(hass, connection, msg):
|
||||
"""Create a response and handle exception."""
|
||||
try:
|
||||
await func(hass, connection, msg)
|
||||
except Exception: # pylint: disable=broad-except
|
||||
_LOGGER.exception("Unexpected exception")
|
||||
connection.send_message_outside(messages.error_message(
|
||||
msg['id'], 'unknown', 'Unexpected error occurred'))
|
||||
|
||||
@callback
|
||||
@wraps(func)
|
||||
def schedule_handler(hass, connection, msg):
|
||||
"""Schedule the handler."""
|
||||
hass.async_create_task(handle_msg_response(hass, connection, msg))
|
||||
|
||||
return schedule_handler
|
||||
|
||||
|
||||
def require_owner(func):
|
||||
"""Websocket decorator to require user to be an owner."""
|
||||
@wraps(func)
|
||||
def with_owner(hass, connection, msg):
|
||||
"""Check owner and call function."""
|
||||
user = connection.request.get('hass_user')
|
||||
|
||||
if user is None or not user.is_owner:
|
||||
connection.to_write.put_nowait(messages.error_message(
|
||||
msg['id'], 'unauthorized', 'This command is for owners only.'))
|
||||
return
|
||||
|
||||
func(hass, connection, msg)
|
||||
|
||||
return with_owner
|
||||
|
||||
|
||||
def ws_require_user(
|
||||
only_owner=False, only_system_user=False, allow_system_user=True,
|
||||
only_active_user=True, only_inactive_user=False):
|
||||
"""Decorate function validating login user exist in current WS connection.
|
||||
|
||||
Will write out error message if not authenticated.
|
||||
"""
|
||||
def validator(func):
|
||||
"""Decorate func."""
|
||||
@wraps(func)
|
||||
def check_current_user(hass, connection, msg):
|
||||
"""Check current user."""
|
||||
def output_error(message_id, message):
|
||||
"""Output error message."""
|
||||
connection.send_message_outside(messages.error_message(
|
||||
msg['id'], message_id, message))
|
||||
|
||||
if connection.user is None:
|
||||
output_error('no_user', 'Not authenticated as a user')
|
||||
return
|
||||
|
||||
if only_owner and not connection.user.is_owner:
|
||||
output_error('only_owner', 'Only allowed as owner')
|
||||
return
|
||||
|
||||
if (only_system_user and
|
||||
not connection.user.system_generated):
|
||||
output_error('only_system_user',
|
||||
'Only allowed as system user')
|
||||
return
|
||||
|
||||
if (not allow_system_user
|
||||
and connection.user.system_generated):
|
||||
output_error('not_system_user', 'Not allowed as system user')
|
||||
return
|
||||
|
||||
if (only_active_user and
|
||||
not connection.user.is_active):
|
||||
output_error('only_active_user',
|
||||
'Only allowed as active user')
|
||||
return
|
||||
|
||||
if only_inactive_user and connection.user.is_active:
|
||||
output_error('only_inactive_user',
|
||||
'Not allowed as active user')
|
||||
return
|
||||
|
||||
return func(hass, connection, msg)
|
||||
|
||||
return check_current_user
|
||||
|
||||
return validator
|
42
homeassistant/components/websocket_api/messages.py
Normal file
42
homeassistant/components/websocket_api/messages.py
Normal file
@ -0,0 +1,42 @@
|
||||
"""Message templates for websocket commands."""
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
|
||||
from . import const
|
||||
|
||||
|
||||
# Minimal requirements of a message
|
||||
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('id'): cv.positive_int,
|
||||
vol.Required('type'): cv.string,
|
||||
}, extra=vol.ALLOW_EXTRA)
|
||||
|
||||
# Base schema to extend by message handlers
|
||||
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
|
||||
vol.Required('id'): cv.positive_int,
|
||||
})
|
||||
|
||||
|
||||
def result_message(iden, result=None):
|
||||
"""Return a success result message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': const.TYPE_RESULT,
|
||||
'success': True,
|
||||
'result': result,
|
||||
}
|
||||
|
||||
|
||||
def error_message(iden, code, message):
|
||||
"""Return an error result message."""
|
||||
return {
|
||||
'id': iden,
|
||||
'type': const.TYPE_RESULT,
|
||||
'success': False,
|
||||
'error': {
|
||||
'code': code,
|
||||
'message': message,
|
||||
},
|
||||
}
|
@ -7,7 +7,8 @@ import pytest
|
||||
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
from homeassistant.const import ATTR_ENTITY_PICTURE
|
||||
from homeassistant.components import camera, http, websocket_api
|
||||
from homeassistant.components import camera, http
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.util.async_ import run_coroutine_threadsafe
|
||||
|
||||
@ -150,7 +151,7 @@ async def test_webocket_camera_thumbnail(hass, hass_ws_client, mock_camera):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result']['content_type'] == 'image/jpeg'
|
||||
assert msg['result']['content'] == \
|
||||
|
@ -9,7 +9,7 @@ from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components.frontend import (
|
||||
DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL,
|
||||
CONF_EXTRA_HTML_URL_ES5)
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
@ -213,7 +213,7 @@ async def test_missing_themes(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result']['default_theme'] == 'default'
|
||||
assert msg['result']['themes'] == {}
|
||||
@ -252,7 +252,7 @@ async def test_get_panels(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result']['map']['component_name'] == 'map'
|
||||
assert msg['result']['map']['url_path'] == 'map'
|
||||
@ -275,7 +275,7 @@ async def test_get_translations(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result'] == {'resources': {'lang': 'nl'}}
|
||||
|
||||
|
@ -3,7 +3,7 @@ from unittest.mock import patch
|
||||
|
||||
from homeassistant.exceptions import HomeAssistantError
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
|
||||
|
||||
async def test_deprecated_lovelace_ui(hass, hass_ws_client):
|
||||
@ -20,7 +20,7 @@ async def test_deprecated_lovelace_ui(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result'] == {'hello': 'world'}
|
||||
|
||||
@ -39,7 +39,7 @@ async def test_deprecated_lovelace_ui_not_found(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success'] is False
|
||||
assert msg['error']['code'] == 'file_not_found'
|
||||
|
||||
@ -58,7 +58,7 @@ async def test_deprecated_lovelace_ui_load_err(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success'] is False
|
||||
assert msg['error']['code'] == 'load_error'
|
||||
|
||||
@ -77,7 +77,7 @@ async def test_lovelace_ui(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result'] == {'hello': 'world'}
|
||||
|
||||
@ -96,7 +96,7 @@ async def test_lovelace_ui_not_found(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success'] is False
|
||||
assert msg['error']['code'] == 'file_not_found'
|
||||
|
||||
@ -115,6 +115,6 @@ async def test_lovelace_ui_load_err(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success'] is False
|
||||
assert msg['error']['code'] == 'load_error'
|
||||
|
@ -3,7 +3,7 @@ import base64
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
@ -30,7 +30,7 @@ async def test_get_panels(hass, hass_ws_client):
|
||||
msg = await client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result']['content_type'] == 'image/jpeg'
|
||||
assert msg['result']['content'] == \
|
||||
|
@ -1,5 +1,5 @@
|
||||
"""The tests for the persistent notification component."""
|
||||
from homeassistant.components import websocket_api
|
||||
from homeassistant.components.websocket_api.const import TYPE_RESULT
|
||||
from homeassistant.setup import setup_component, async_setup_component
|
||||
import homeassistant.components.persistent_notification as pn
|
||||
|
||||
@ -151,7 +151,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 0
|
||||
@ -165,7 +165,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
|
||||
})
|
||||
msg = await client.receive_json()
|
||||
assert msg['id'] == 6
|
||||
assert msg['type'] == websocket_api.TYPE_RESULT
|
||||
assert msg['type'] == TYPE_RESULT
|
||||
assert msg['success']
|
||||
notifications = msg['result']
|
||||
assert len(notifications) == 1
|
||||
|
@ -1,558 +0,0 @@
|
||||
"""Tests for the Home Assistant Websocket API."""
|
||||
import asyncio
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from aiohttp import WSMsgType
|
||||
from async_timeout import timeout
|
||||
import pytest
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_coro, async_mock_service
|
||||
|
||||
API_PASSWORD = 'test1234'
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_client(hass, hass_ws_client):
|
||||
"""Create a websocket client."""
|
||||
return hass.loop.run_until_complete(hass_ws_client(hass))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_auth_websocket_client(hass, loop, aiohttp_client):
|
||||
"""Websocket connection that requires authentication."""
|
||||
assert loop.run_until_complete(
|
||||
async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
}))
|
||||
|
||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
yield ws
|
||||
|
||||
if not ws.closed:
|
||||
loop.run_until_complete(ws.close())
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
||||
yield
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_auth_via_msg(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
yield from no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
msg = yield from no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
with patch('homeassistant.components.websocket_api.process_wrong_login',
|
||||
return_value=mock_coro()) as mock_process_wrong_login:
|
||||
yield from no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD + 'wrong'
|
||||
})
|
||||
|
||||
msg = yield from no_auth_websocket_client.receive_json()
|
||||
|
||||
assert mock_process_wrong_login.called
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['message'] == 'Invalid access token or password'
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
|
||||
"""Verify that before authentication, only auth messages are allowed."""
|
||||
yield from no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = yield from no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['message'].startswith('Message incorrectly formatted')
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_invalid_message_format(websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
yield from websocket_client.send_json({'type': 5})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
error = msg['error']
|
||||
assert error['code'] == wapi.ERR_INVALID_FORMAT
|
||||
assert error['message'].startswith('Message incorrectly formatted')
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_invalid_json(websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
yield from websocket_client.send_str('this is not JSON')
|
||||
|
||||
msg = yield from websocket_client.receive()
|
||||
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_quiting_hass(hass, websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
with patch.object(hass.loop, 'stop'):
|
||||
yield from hass.async_stop()
|
||||
|
||||
msg = yield from websocket_client.receive()
|
||||
|
||||
assert msg.type == WSMsgType.CLOSE
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_call_service(hass, websocket_client):
|
||||
"""Test call service command."""
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def service_call(call):
|
||||
calls.append(call)
|
||||
|
||||
hass.services.async_register('domain_test', 'test_service', service_call)
|
||||
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_subscribe_unsubscribe_events(hass, websocket_client):
|
||||
"""Test subscribe/unsubscribe events command."""
|
||||
init_count = sum(hass.bus.async_listeners().values())
|
||||
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_SUBSCRIBE_EVENTS,
|
||||
'event_type': 'test_event'
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
# Verify we have a new listener
|
||||
assert sum(hass.bus.async_listeners().values()) == init_count + 1
|
||||
|
||||
hass.bus.async_fire('ignore_event')
|
||||
hass.bus.async_fire('test_event', {'hello': 'world'})
|
||||
hass.bus.async_fire('ignore_event')
|
||||
|
||||
with timeout(3, loop=hass.loop):
|
||||
msg = yield from websocket_client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_EVENT
|
||||
event = msg['event']
|
||||
|
||||
assert event['event_type'] == 'test_event'
|
||||
assert event['data'] == {'hello': 'world'}
|
||||
assert event['origin'] == 'LOCAL'
|
||||
|
||||
yield from websocket_client.send_json({
|
||||
'id': 6,
|
||||
'type': wapi.TYPE_UNSUBSCRIBE_EVENTS,
|
||||
'subscription': 5
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 6
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
# Check our listener got unsubscribed
|
||||
assert sum(hass.bus.async_listeners().values()) == init_count
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_states(hass, websocket_client):
|
||||
"""Test get_states command."""
|
||||
hass.states.async_set('greeting.hello', 'world')
|
||||
hass.states.async_set('greeting.bye', 'universe')
|
||||
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_GET_STATES,
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
states = []
|
||||
for state in hass.states.async_all():
|
||||
state = state.as_dict()
|
||||
state['last_changed'] = state['last_changed'].isoformat()
|
||||
state['last_updated'] = state['last_updated'].isoformat()
|
||||
states.append(state)
|
||||
|
||||
assert msg['result'] == states
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_services(hass, websocket_client):
|
||||
"""Test get_services command."""
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_GET_SERVICES,
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result'] == hass.services.async_services()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_get_config(hass, websocket_client):
|
||||
"""Test get_config command."""
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_GET_CONFIG,
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
if 'components' in msg['result']:
|
||||
msg['result']['components'] = set(msg['result']['components'])
|
||||
if 'whitelist_external_dirs' in msg['result']:
|
||||
msg['result']['whitelist_external_dirs'] = \
|
||||
set(msg['result']['whitelist_external_dirs'])
|
||||
|
||||
assert msg['result'] == hass.config.as_dict()
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_ping(websocket_client):
|
||||
"""Test get_panels command."""
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_PING,
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_PONG
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||
"""Test get_panels command."""
|
||||
for idx in range(10):
|
||||
yield from websocket_client.send_json({
|
||||
'id': idx + 1,
|
||||
'type': wapi.TYPE_PING,
|
||||
})
|
||||
msg = yield from websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unknown_command(websocket_client):
|
||||
"""Test get_panels command."""
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': 'unknown_command',
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert not msg['success']
|
||||
assert msg['error']['code'] == wapi.ERR_UNKNOWN_COMMAND
|
||||
|
||||
|
||||
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||
hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True),\
|
||||
patch('homeassistant.auth.AuthManager.support_legacy',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': 'incorrect'
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||
hass_access_token):
|
||||
"""Test that the user is set in the service call context."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await ws.receive_json()
|
||||
assert msg['success']
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
assert call.context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||
"""Test that connection without user sets context."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
'type': wapi.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await ws.receive_json()
|
||||
assert msg['success']
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
assert call.context.user_id is None
|
||||
|
||||
|
||||
async def test_handler_failing(hass, websocket_client):
|
||||
"""Test a command that raises."""
|
||||
hass.components.websocket_api.async_register_command(
|
||||
'bla', Mock(side_effect=TypeError),
|
||||
wapi.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': 'bla',
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == wapi.TYPE_RESULT
|
||||
assert not msg['success']
|
||||
assert msg['error']['code'] == wapi.ERR_UNKNOWN_ERROR
|
2
tests/components/websocket_api/__init__.py
Normal file
2
tests/components/websocket_api/__init__.py
Normal file
@ -0,0 +1,2 @@
|
||||
"""Tests for the websocket API."""
|
||||
API_PASSWORD = 'test1234'
|
35
tests/components/websocket_api/conftest.py
Normal file
35
tests/components/websocket_api/conftest.py
Normal file
@ -0,0 +1,35 @@
|
||||
"""Fixtures for websocket tests."""
|
||||
import pytest
|
||||
|
||||
from homeassistant.setup import async_setup_component
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
|
||||
from . import API_PASSWORD
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def websocket_client(hass, hass_ws_client):
|
||||
"""Create a websocket client."""
|
||||
return hass.loop.run_until_complete(hass_ws_client(hass))
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def no_auth_websocket_client(hass, loop, aiohttp_client):
|
||||
"""Websocket connection that requires authentication."""
|
||||
assert loop.run_until_complete(
|
||||
async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
}))
|
||||
|
||||
client = loop.run_until_complete(aiohttp_client(hass.http.app))
|
||||
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
|
||||
|
||||
auth_ok = loop.run_until_complete(ws.receive_json())
|
||||
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
yield ws
|
||||
|
||||
if not ws.closed:
|
||||
loop.run_until_complete(ws.close())
|
186
tests/components/websocket_api/test_auth.py
Normal file
186
tests/components/websocket_api/test_auth.py
Normal file
@ -0,0 +1,186 @@
|
||||
"""Test auth of websocket API."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api import commands
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import mock_coro
|
||||
|
||||
from . import API_PASSWORD
|
||||
|
||||
|
||||
async def test_auth_via_msg(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
await no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
|
||||
"""Test authenticating."""
|
||||
with patch('homeassistant.components.websocket_api.process_wrong_login',
|
||||
return_value=mock_coro()) as mock_process_wrong_login:
|
||||
await no_auth_websocket_client.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD + 'wrong'
|
||||
})
|
||||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert mock_process_wrong_login.called
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['message'] == 'Invalid access token or password'
|
||||
|
||||
|
||||
async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
|
||||
"""Verify that before authentication, only auth messages are allowed."""
|
||||
await no_auth_websocket_client.send_json({
|
||||
'type': commands.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await no_auth_websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
assert msg['message'].startswith('Message incorrectly formatted')
|
||||
|
||||
|
||||
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_active_user_inactive(hass, aiohttp_client,
|
||||
hass_access_token):
|
||||
"""Test authenticating with a token."""
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
refresh_token.user.is_active = False
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
||||
|
||||
|
||||
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active',
|
||||
return_value=True),\
|
||||
patch('homeassistant.auth.AuthManager.support_legacy',
|
||||
return_value=True):
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
|
||||
async def test_auth_with_invalid_token(hass, aiohttp_client):
|
||||
"""Test authenticating with a token."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': 'incorrect'
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
|
260
tests/components/websocket_api/test_commands.py
Normal file
260
tests/components/websocket_api/test_commands.py
Normal file
@ -0,0 +1,260 @@
|
||||
"""Tests for WebSocket API commands."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from async_timeout import timeout
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api import const, commands
|
||||
from homeassistant.setup import async_setup_component
|
||||
|
||||
from tests.common import async_mock_service
|
||||
|
||||
from . import API_PASSWORD
|
||||
|
||||
|
||||
async def test_call_service(hass, websocket_client):
|
||||
"""Test call service command."""
|
||||
calls = []
|
||||
|
||||
@callback
|
||||
def service_call(call):
|
||||
calls.append(call)
|
||||
|
||||
hass.services.async_register('domain_test', 'test_service', service_call)
|
||||
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
|
||||
|
||||
async def test_subscribe_unsubscribe_events(hass, websocket_client):
|
||||
"""Test subscribe/unsubscribe events command."""
|
||||
init_count = sum(hass.bus.async_listeners().values())
|
||||
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_SUBSCRIBE_EVENTS,
|
||||
'event_type': 'test_event'
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
# Verify we have a new listener
|
||||
assert sum(hass.bus.async_listeners().values()) == init_count + 1
|
||||
|
||||
hass.bus.async_fire('ignore_event')
|
||||
hass.bus.async_fire('test_event', {'hello': 'world'})
|
||||
hass.bus.async_fire('ignore_event')
|
||||
|
||||
with timeout(3, loop=hass.loop):
|
||||
msg = await websocket_client.receive_json()
|
||||
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == commands.TYPE_EVENT
|
||||
event = msg['event']
|
||||
|
||||
assert event['event_type'] == 'test_event'
|
||||
assert event['data'] == {'hello': 'world'}
|
||||
assert event['origin'] == 'LOCAL'
|
||||
|
||||
await websocket_client.send_json({
|
||||
'id': 6,
|
||||
'type': commands.TYPE_UNSUBSCRIBE_EVENTS,
|
||||
'subscription': 5
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 6
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
# Check our listener got unsubscribed
|
||||
assert sum(hass.bus.async_listeners().values()) == init_count
|
||||
|
||||
|
||||
async def test_get_states(hass, websocket_client):
|
||||
"""Test get_states command."""
|
||||
hass.states.async_set('greeting.hello', 'world')
|
||||
hass.states.async_set('greeting.bye', 'universe')
|
||||
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_GET_STATES,
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
states = []
|
||||
for state in hass.states.async_all():
|
||||
state = state.as_dict()
|
||||
state['last_changed'] = state['last_changed'].isoformat()
|
||||
state['last_updated'] = state['last_updated'].isoformat()
|
||||
states.append(state)
|
||||
|
||||
assert msg['result'] == states
|
||||
|
||||
|
||||
async def test_get_services(hass, websocket_client):
|
||||
"""Test get_services command."""
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_GET_SERVICES,
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
assert msg['result'] == hass.services.async_services()
|
||||
|
||||
|
||||
async def test_get_config(hass, websocket_client):
|
||||
"""Test get_config command."""
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_GET_CONFIG,
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert msg['success']
|
||||
|
||||
if 'components' in msg['result']:
|
||||
msg['result']['components'] = set(msg['result']['components'])
|
||||
if 'whitelist_external_dirs' in msg['result']:
|
||||
msg['result']['whitelist_external_dirs'] = \
|
||||
set(msg['result']['whitelist_external_dirs'])
|
||||
|
||||
assert msg['result'] == hass.config.as_dict()
|
||||
|
||||
|
||||
async def test_ping(websocket_client):
|
||||
"""Test get_panels command."""
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_PING,
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == commands.TYPE_PONG
|
||||
|
||||
|
||||
async def test_call_service_context_with_user(hass, aiohttp_client,
|
||||
hass_access_token):
|
||||
"""Test that the user is set in the service call context."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
with patch('homeassistant.auth.AuthManager.active') as auth_active:
|
||||
auth_active.return_value = True
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'access_token': hass_access_token
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await ws.receive_json()
|
||||
assert msg['success']
|
||||
|
||||
refresh_token = await hass.auth.async_validate_access_token(
|
||||
hass_access_token)
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
assert call.context.user_id == refresh_token.user.id
|
||||
|
||||
|
||||
async def test_call_service_context_no_user(hass, aiohttp_client):
|
||||
"""Test that connection without user sets context."""
|
||||
assert await async_setup_component(hass, 'websocket_api', {
|
||||
'http': {
|
||||
'api_password': API_PASSWORD
|
||||
}
|
||||
})
|
||||
|
||||
calls = async_mock_service(hass, 'domain_test', 'test_service')
|
||||
client = await aiohttp_client(hass.http.app)
|
||||
|
||||
async with client.ws_connect(wapi.URL) as ws:
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
|
||||
|
||||
await ws.send_json({
|
||||
'type': wapi.TYPE_AUTH,
|
||||
'api_password': API_PASSWORD
|
||||
})
|
||||
|
||||
auth_msg = await ws.receive_json()
|
||||
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
|
||||
|
||||
await ws.send_json({
|
||||
'id': 5,
|
||||
'type': commands.TYPE_CALL_SERVICE,
|
||||
'domain': 'domain_test',
|
||||
'service': 'test_service',
|
||||
'service_data': {
|
||||
'hello': 'world'
|
||||
}
|
||||
})
|
||||
|
||||
msg = await ws.receive_json()
|
||||
assert msg['success']
|
||||
|
||||
assert len(calls) == 1
|
||||
call = calls[0]
|
||||
assert call.domain == 'domain_test'
|
||||
assert call.service == 'test_service'
|
||||
assert call.data == {'hello': 'world'}
|
||||
assert call.context.user_id is None
|
92
tests/components/websocket_api/test_init.py
Normal file
92
tests/components/websocket_api/test_init.py
Normal file
@ -0,0 +1,92 @@
|
||||
"""Tests for the Home Assistant Websocket API."""
|
||||
import asyncio
|
||||
from unittest.mock import patch, Mock
|
||||
|
||||
from aiohttp import WSMsgType
|
||||
import pytest
|
||||
|
||||
from homeassistant.components import websocket_api as wapi
|
||||
from homeassistant.components.websocket_api import const, commands, messages
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_low_queue():
|
||||
"""Mock a low queue."""
|
||||
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
|
||||
yield
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_invalid_message_format(websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
yield from websocket_client.send_json({'type': 5})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
error = msg['error']
|
||||
assert error['code'] == const.ERR_INVALID_FORMAT
|
||||
assert error['message'].startswith('Message incorrectly formatted')
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_invalid_json(websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
yield from websocket_client.send_str('this is not JSON')
|
||||
|
||||
msg = yield from websocket_client.receive()
|
||||
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_quiting_hass(hass, websocket_client):
|
||||
"""Test sending invalid JSON."""
|
||||
with patch.object(hass.loop, 'stop'):
|
||||
yield from hass.async_stop()
|
||||
|
||||
msg = yield from websocket_client.receive()
|
||||
|
||||
assert msg.type == WSMsgType.CLOSE
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
|
||||
"""Test get_panels command."""
|
||||
for idx in range(10):
|
||||
yield from websocket_client.send_json({
|
||||
'id': idx + 1,
|
||||
'type': commands.TYPE_PING,
|
||||
})
|
||||
msg = yield from websocket_client.receive()
|
||||
assert msg.type == WSMsgType.close
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def test_unknown_command(websocket_client):
|
||||
"""Test get_panels command."""
|
||||
yield from websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': 'unknown_command',
|
||||
})
|
||||
|
||||
msg = yield from websocket_client.receive_json()
|
||||
assert not msg['success']
|
||||
assert msg['error']['code'] == const.ERR_UNKNOWN_COMMAND
|
||||
|
||||
|
||||
async def test_handler_failing(hass, websocket_client):
|
||||
"""Test a command that raises."""
|
||||
hass.components.websocket_api.async_register_command(
|
||||
'bla', Mock(side_effect=TypeError),
|
||||
messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
|
||||
await websocket_client.send_json({
|
||||
'id': 5,
|
||||
'type': 'bla',
|
||||
})
|
||||
|
||||
msg = await websocket_client.receive_json()
|
||||
assert msg['id'] == 5
|
||||
assert msg['type'] == const.TYPE_RESULT
|
||||
assert not msg['success']
|
||||
assert msg['error']['code'] == const.ERR_UNKNOWN_ERROR
|
Loading…
x
Reference in New Issue
Block a user