Break up websocket component (#17003)

* Break up websocket component

* Lint
This commit is contained in:
Paulus Schoutsen 2018-10-01 11:21:00 +02:00 committed by GitHub
parent 9edf1e5151
commit 22a80cf733
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 1041 additions and 1003 deletions

View File

@ -452,27 +452,23 @@ class CameraMjpegStream(CameraView):
raise web.HTTPBadRequest()
@callback
def websocket_camera_thumbnail(hass, connection, msg):
@websocket_api.async_response
async def websocket_camera_thumbnail(hass, connection, msg):
"""Handle get camera thumbnail websocket command.
Async friendly.
"""
async def send_camera_still():
"""Send a camera still."""
try:
image = await async_get_image(hass, msg['entity_id'])
connection.send_message_outside(websocket_api.result_message(
msg['id'], {
'content_type': image.content_type,
'content': base64.b64encode(image.content).decode('utf-8')
}
))
except HomeAssistantError:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
hass.async_add_job(send_camera_still())
try:
image = await async_get_image(hass, msg['entity_id'])
connection.send_message_outside(websocket_api.result_message(
msg['id'], {
'content_type': image.content_type,
'content': base64.b64encode(image.content).decode('utf-8')
}
))
except HomeAssistantError:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'image_fetch_failed', 'Unable to fetch image'))
async def async_handle_snapshot_service(camera, service):

View File

@ -3,6 +3,7 @@ import voluptuous as vol
from homeassistant.core import callback
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_owner
WS_TYPE_LIST = 'config/auth/list'
@ -41,7 +42,7 @@ async def async_setup(hass):
@callback
@websocket_api.require_owner
@require_owner
def websocket_list(hass, connection, msg):
"""Return a list of users."""
async def send_users():
@ -55,7 +56,7 @@ def websocket_list(hass, connection, msg):
@callback
@websocket_api.require_owner
@require_owner
def websocket_delete(hass, connection, msg):
"""Delete a user."""
async def delete_user():
@ -82,7 +83,7 @@ def websocket_delete(hass, connection, msg):
@callback
@websocket_api.require_owner
@require_owner
def websocket_create(hass, connection, msg):
"""Create a user."""
async def create_user():

View File

@ -4,6 +4,7 @@ import voluptuous as vol
from homeassistant.auth.providers import homeassistant as auth_ha
from homeassistant.core import callback
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.decorators import require_owner
WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create'
@ -55,7 +56,7 @@ def _get_provider(hass):
@callback
@websocket_api.require_owner
@require_owner
def websocket_create(hass, connection, msg):
"""Create credentials and attach to a user."""
async def create_creds():
@ -96,7 +97,7 @@ def websocket_create(hass, connection, msg):
@callback
@websocket_api.require_owner
@require_owner
def websocket_delete(hass, connection, msg):
"""Delete username and related credential."""
async def delete_creds():

View File

@ -4,6 +4,8 @@ import voluptuous as vol
from homeassistant.core import callback
from homeassistant.helpers.entity_registry import async_get_registry
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.const import ERR_NOT_FOUND
from homeassistant.components.websocket_api.decorators import async_response
from homeassistant.helpers import config_validation as cv
DEPENDENCIES = ['websocket_api']
@ -46,89 +48,77 @@ async def async_setup(hass):
return True
@callback
def websocket_list_entities(hass, connection, msg):
@async_response
async def websocket_list_entities(hass, connection, msg):
"""Handle list registry entries command.
Async friendly.
"""
async def retrieve_entities():
"""Get entities from registry."""
registry = await async_get_registry(hass)
connection.send_message_outside(websocket_api.result_message(
msg['id'], [{
'config_entry_id': entry.config_entry_id,
'device_id': entry.device_id,
'disabled_by': entry.disabled_by,
'entity_id': entry.entity_id,
'name': entry.name,
'platform': entry.platform,
} for entry in registry.entities.values()]
))
hass.async_add_job(retrieve_entities())
registry = await async_get_registry(hass)
connection.send_message_outside(websocket_api.result_message(
msg['id'], [{
'config_entry_id': entry.config_entry_id,
'device_id': entry.device_id,
'disabled_by': entry.disabled_by,
'entity_id': entry.entity_id,
'name': entry.name,
'platform': entry.platform,
} for entry in registry.entities.values()]
))
@callback
def websocket_get_entity(hass, connection, msg):
@async_response
async def websocket_get_entity(hass, connection, msg):
"""Handle get entity registry entry command.
Async friendly.
"""
async def retrieve_entity():
"""Get entity from registry."""
registry = await async_get_registry(hass)
entry = registry.entities.get(msg['entity_id'])
registry = await async_get_registry(hass)
entry = registry.entities.get(msg['entity_id'])
if entry is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
return
if entry is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
hass.async_add_job(retrieve_entity())
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
@callback
def websocket_update_entity(hass, connection, msg):
@async_response
async def websocket_update_entity(hass, connection, msg):
"""Handle get camera thumbnail websocket command.
Async friendly.
"""
async def update_entity():
"""Get entity from registry."""
registry = await async_get_registry(hass)
registry = await async_get_registry(hass)
if msg['entity_id'] not in registry.entities:
connection.send_message_outside(websocket_api.error_message(
msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found'))
return
if msg['entity_id'] not in registry.entities:
connection.send_message_outside(websocket_api.error_message(
msg['id'], ERR_NOT_FOUND, 'Entity not found'))
return
changes = {}
changes = {}
if 'name' in msg:
changes['name'] = msg['name']
if 'name' in msg:
changes['name'] = msg['name']
if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id']
if 'new_entity_id' in msg:
changes['new_entity_id'] = msg['new_entity_id']
try:
if changes:
entry = registry.async_update_entity(
msg['entity_id'], **changes)
except ValueError as err:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err)
))
else:
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
hass.async_create_task(update_entity())
try:
if changes:
entry = registry.async_update_entity(
msg['entity_id'], **changes)
except ValueError as err:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'invalid_info', str(err)
))
else:
connection.send_message_outside(websocket_api.result_message(
msg['id'], _entry_dict(entry)
))
@callback

View File

@ -28,7 +28,6 @@ from homeassistant.const import (
SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_VOLUME_DOWN,
SERVICE_VOLUME_MUTE, SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, STATE_IDLE,
STATE_OFF, STATE_PLAYING, STATE_UNKNOWN)
from homeassistant.core import callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa
@ -865,8 +864,8 @@ class MediaPlayerImageView(HomeAssistantView):
body=data, content_type=content_type, headers=headers)
@callback
def websocket_handle_thumbnail(hass, connection, msg):
@websocket_api.async_response
async def websocket_handle_thumbnail(hass, connection, msg):
"""Handle get media player cover command.
Async friendly.
@ -879,20 +878,16 @@ def websocket_handle_thumbnail(hass, connection, msg):
msg['id'], 'entity_not_found', 'Entity not found'))
return
async def send_image():
"""Send image."""
data, content_type = await player.async_get_media_image()
data, content_type = await player.async_get_media_image()
if data is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'thumbnail_fetch_failed',
'Failed to fetch thumbnail'))
return
if data is None:
connection.send_message_outside(websocket_api.error_message(
msg['id'], 'thumbnail_fetch_failed',
'Failed to fetch thumbnail'))
return
connection.send_message_outside(websocket_api.result_message(
msg['id'], {
'content_type': content_type,
'content': base64.b64encode(data).decode('utf-8')
}))
hass.async_add_job(send_image())
connection.send_message_outside(websocket_api.result_message(
msg['id'], {
'content_type': content_type,
'content': base64.b64encode(data).decode('utf-8')
}))

View File

@ -7,7 +7,7 @@ https://developers.home-assistant.io/docs/external_api_websocket.html
import asyncio
from concurrent import futures
from contextlib import suppress
from functools import partial, wraps
from functools import partial
import json
import logging
@ -15,20 +15,18 @@ from aiohttp import web
import voluptuous as vol
from voluptuous.humanize import humanize_error
from homeassistant.const import (
MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP,
__version__)
from homeassistant.core import Context, callback, HomeAssistant
from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__
from homeassistant.core import Context, callback
from homeassistant.loader import bind_hass
from homeassistant.helpers.json import JSONEncoder
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from homeassistant.components.http import HomeAssistantView
from homeassistant.components.http.auth import validate_password
from homeassistant.components.http.const import KEY_AUTHENTICATED
from homeassistant.components.http.ban import process_wrong_login, \
process_success_login
from . import commands, const, decorators, messages
DOMAIN = 'websocket_api'
URL = '/api/websocket'
@ -36,87 +34,32 @@ DEPENDENCIES = ('http',)
MAX_PENDING_MSG = 512
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_RESULT = 'result'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
_LOGGER = logging.getLogger(__name__)
JSON_DUMP = partial(json.dumps, cls=JSONEncoder)
TYPE_AUTH = 'auth'
TYPE_AUTH_INVALID = 'auth_invalid'
TYPE_AUTH_OK = 'auth_ok'
TYPE_AUTH_REQUIRED = 'auth_required'
# Backwards compat
# pylint: disable=invalid-name
BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA
error_message = messages.error_message
result_message = messages.result_message
async_response = decorators.async_response
ws_require_user = decorators.ws_require_user
# pylint: enable=invalid-name
AUTH_MESSAGE_SCHEMA = vol.Schema({
vol.Required('type'): TYPE_AUTH,
vol.Exclusive('api_password', 'auth'): str,
vol.Exclusive('access_token', 'auth'): str,
})
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
})
SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data'): dict
})
SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_STATES,
})
SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_SERVICES,
})
SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_CONFIG,
})
SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_PING,
})
# Define the possible errors that occur when connections are cancelled.
# Originally, this was just asyncio.CancelledError, but issue #9546 showed
@ -148,46 +91,6 @@ def auth_invalid_message(message):
}
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': TYPE_RESULT,
'success': True,
'result': result,
}
@bind_hass
@callback
def async_register_command(hass, command, handler, schema):
@ -198,43 +101,10 @@ def async_register_command(hass, command, handler, schema):
handlers[command] = (handler, schema)
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
async def async_setup(hass, config):
"""Initialize the websocket API."""
hass.http.register_view(WebsocketAPIView)
async_register_command(hass, TYPE_SUBSCRIBE_EVENTS,
handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS,
handle_unsubscribe_events,
SCHEMA_UNSUBSCRIBE_EVENTS)
async_register_command(hass, TYPE_CALL_SERVICE,
handle_call_service, SCHEMA_CALL_SERVICE)
async_register_command(hass, TYPE_GET_STATES,
handle_get_states, SCHEMA_GET_STATES)
async_register_command(hass, TYPE_GET_SERVICES,
handle_get_services, SCHEMA_GET_SERVICES)
async_register_command(hass, TYPE_GET_CONFIG,
handle_get_config, SCHEMA_GET_CONFIG)
async_register_command(hass, TYPE_PING,
handle_ping, SCHEMA_PING)
commands.async_register_commands(hass)
return True
@ -389,19 +259,19 @@ class ActiveConnection:
while msg:
self.debug("Received", msg)
msg = MINIMAL_MESSAGE_SCHEMA(msg)
msg = messages.MINIMAL_MESSAGE_SCHEMA(msg)
cur_id = msg['id']
if cur_id <= last_id:
self.to_write.put_nowait(error_message(
cur_id, ERR_ID_REUSE,
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_ID_REUSE,
'Identifier values have to increase.'))
elif msg['type'] not in handlers:
self.log_error(
'Received invalid command: {}'.format(msg['type']))
self.to_write.put_nowait(error_message(
cur_id, ERR_UNKNOWN_COMMAND,
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_UNKNOWN_COMMAND,
'Unknown command.'))
else:
@ -410,8 +280,8 @@ class ActiveConnection:
handler(self.hass, self, schema(msg))
except Exception: # pylint: disable=broad-except
_LOGGER.exception('Error handling message: %s', msg)
self.to_write.put_nowait(error_message(
cur_id, ERR_UNKNOWN_ERROR,
self.to_write.put_nowait(messages.error_message(
cur_id, const.ERR_UNKNOWN_ERROR,
'Unknown error.'))
last_id = cur_id
@ -435,8 +305,8 @@ class ActiveConnection:
else:
iden = None
final_message = error_message(
iden, ERR_INVALID_FORMAT, error_msg)
final_message = messages.error_message(
iden, const.ERR_INVALID_FORMAT, error_msg)
except TypeError as err:
if wsock.closed:
@ -485,170 +355,3 @@ class ActiveConnection:
self.debug("Closed connection")
return wsock
def async_response(func):
"""Decorate an async function to handle WebSocket API messages."""
async def handle_msg_response(hass, connection, msg):
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
connection.send_message_outside(error_message(
msg['id'], 'unknown', 'Unexpected error occurred'))
@callback
@wraps(func)
def schedule_handler(hass, connection, msg):
"""Schedule the handler."""
hass.async_create_task(handle_msg_response(hass, connection, msg))
return schedule_handler
@callback
def handle_subscribe_events(hass, connection, msg):
"""Handle subscribe events command.
Async friendly.
"""
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message_outside(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.to_write.put_nowait(result_message(msg['id']))
@callback
def handle_unsubscribe_events(hass, connection, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(result_message(msg['id']))
else:
connection.to_write.put_nowait(error_message(
msg['id'], ERR_NOT_FOUND, 'Subscription not found.'))
@async_response
async def handle_call_service(hass, connection, msg):
"""Handle call service command.
Async friendly.
"""
blocking = True
if (msg['domain'] == 'homeassistant' and
msg['service'] in ['restart', 'stop']):
blocking = False
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg))
connection.send_message_outside(result_message(msg['id']))
@callback
def handle_get_states(hass, connection, msg):
"""Handle get states command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.states.async_all()))
@async_response
async def handle_get_services(hass, connection, msg):
"""Handle get services command.
Async friendly.
"""
descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside(
result_message(msg['id'], descriptions))
@callback
def handle_get_config(hass, connection, msg):
"""Handle get config command.
Async friendly.
"""
connection.to_write.put_nowait(result_message(
msg['id'], hass.config.as_dict()))
@callback
def handle_ping(hass, connection, msg):
"""Handle ping command.
Async friendly.
"""
connection.to_write.put_nowait(pong_message(msg['id']))
def ws_require_user(
only_owner=False, only_system_user=False, allow_system_user=True,
only_active_user=True, only_inactive_user=False):
"""Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated.
"""
def validator(func):
"""Decorate func."""
@wraps(func)
def check_current_user(hass: HomeAssistant,
connection: ActiveConnection,
msg):
"""Check current user."""
def output_error(message_id, message):
"""Output error message."""
connection.send_message_outside(error_message(
msg['id'], message_id, message))
if connection.user is None:
output_error('no_user', 'Not authenticated as a user')
return
if only_owner and not connection.user.is_owner:
output_error('only_owner', 'Only allowed as owner')
return
if (only_system_user and
not connection.user.system_generated):
output_error('only_system_user',
'Only allowed as system user')
return
if (not allow_system_user
and connection.user.system_generated):
output_error('not_system_user', 'Not allowed as system user')
return
if (only_active_user and
not connection.user.is_active):
output_error('only_active_user',
'Only allowed as active user')
return
if only_inactive_user and connection.user.is_active:
output_error('only_inactive_user',
'Not allowed as active user')
return
return func(hass, connection, msg)
return check_current_user
return validator

View File

@ -0,0 +1,183 @@
"""Commands part of Websocket API."""
import voluptuous as vol
from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED
from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.service import async_get_all_descriptions
from . import const, decorators, messages
TYPE_CALL_SERVICE = 'call_service'
TYPE_EVENT = 'event'
TYPE_GET_CONFIG = 'get_config'
TYPE_GET_SERVICES = 'get_services'
TYPE_GET_STATES = 'get_states'
TYPE_PING = 'ping'
TYPE_PONG = 'pong'
TYPE_SUBSCRIBE_EVENTS = 'subscribe_events'
TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events'
@callback
def async_register_commands(hass):
"""Register commands."""
async_reg = hass.components.websocket_api.async_register_command
async_reg(TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events,
SCHEMA_SUBSCRIBE_EVENTS)
async_reg(TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events,
SCHEMA_UNSUBSCRIBE_EVENTS)
async_reg(TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE)
async_reg(TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES)
async_reg(TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES)
async_reg(TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG)
async_reg(TYPE_PING, handle_ping, SCHEMA_PING)
SCHEMA_SUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_SUBSCRIBE_EVENTS,
vol.Optional('event_type', default=MATCH_ALL): str,
})
SCHEMA_UNSUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS,
vol.Required('subscription'): cv.positive_int,
})
SCHEMA_CALL_SERVICE = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_CALL_SERVICE,
vol.Required('domain'): str,
vol.Required('service'): str,
vol.Optional('service_data'): dict
})
SCHEMA_GET_STATES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_STATES,
})
SCHEMA_GET_SERVICES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_SERVICES,
})
SCHEMA_GET_CONFIG = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_GET_CONFIG,
})
SCHEMA_PING = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({
vol.Required('type'): TYPE_PING,
})
def event_message(iden, event):
"""Return an event message."""
return {
'id': iden,
'type': TYPE_EVENT,
'event': event.as_dict(),
}
def pong_message(iden):
"""Return a pong message."""
return {
'id': iden,
'type': TYPE_PONG,
}
@callback
def handle_subscribe_events(hass, connection, msg):
"""Handle subscribe events command.
Async friendly.
"""
async def forward_events(event):
"""Forward events to websocket."""
if event.event_type == EVENT_TIME_CHANGED:
return
connection.send_message_outside(event_message(msg['id'], event))
connection.event_listeners[msg['id']] = hass.bus.async_listen(
msg['event_type'], forward_events)
connection.to_write.put_nowait(messages.result_message(msg['id']))
@callback
def handle_unsubscribe_events(hass, connection, msg):
"""Handle unsubscribe events command.
Async friendly.
"""
subscription = msg['subscription']
if subscription in connection.event_listeners:
connection.event_listeners.pop(subscription)()
connection.to_write.put_nowait(messages.result_message(msg['id']))
else:
connection.to_write.put_nowait(messages.error_message(
msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.'))
@decorators.async_response
async def handle_call_service(hass, connection, msg):
"""Handle call service command.
Async friendly.
"""
blocking = True
if (msg['domain'] == 'homeassistant' and
msg['service'] in ['restart', 'stop']):
blocking = False
await hass.services.async_call(
msg['domain'], msg['service'], msg.get('service_data'), blocking,
connection.context(msg))
connection.send_message_outside(messages.result_message(msg['id']))
@callback
def handle_get_states(hass, connection, msg):
"""Handle get states command.
Async friendly.
"""
connection.to_write.put_nowait(messages.result_message(
msg['id'], hass.states.async_all()))
@decorators.async_response
async def handle_get_services(hass, connection, msg):
"""Handle get services command.
Async friendly.
"""
descriptions = await async_get_all_descriptions(hass)
connection.send_message_outside(
messages.result_message(msg['id'], descriptions))
@callback
def handle_get_config(hass, connection, msg):
"""Handle get config command.
Async friendly.
"""
connection.to_write.put_nowait(messages.result_message(
msg['id'], hass.config.as_dict()))
@callback
def handle_ping(hass, connection, msg):
"""Handle ping command.
Async friendly.
"""
connection.to_write.put_nowait(pong_message(msg['id']))

View File

@ -0,0 +1,8 @@
"""Websocket constants."""
ERR_ID_REUSE = 1
ERR_INVALID_FORMAT = 2
ERR_NOT_FOUND = 3
ERR_UNKNOWN_COMMAND = 4
ERR_UNKNOWN_ERROR = 5
TYPE_RESULT = 'result'

View File

@ -0,0 +1,101 @@
"""Decorators for the Websocket API."""
from functools import wraps
import logging
from homeassistant.core import callback
from . import messages
_LOGGER = logging.getLogger(__name__)
def async_response(func):
"""Decorate an async function to handle WebSocket API messages."""
async def handle_msg_response(hass, connection, msg):
"""Create a response and handle exception."""
try:
await func(hass, connection, msg)
except Exception: # pylint: disable=broad-except
_LOGGER.exception("Unexpected exception")
connection.send_message_outside(messages.error_message(
msg['id'], 'unknown', 'Unexpected error occurred'))
@callback
@wraps(func)
def schedule_handler(hass, connection, msg):
"""Schedule the handler."""
hass.async_create_task(handle_msg_response(hass, connection, msg))
return schedule_handler
def require_owner(func):
"""Websocket decorator to require user to be an owner."""
@wraps(func)
def with_owner(hass, connection, msg):
"""Check owner and call function."""
user = connection.request.get('hass_user')
if user is None or not user.is_owner:
connection.to_write.put_nowait(messages.error_message(
msg['id'], 'unauthorized', 'This command is for owners only.'))
return
func(hass, connection, msg)
return with_owner
def ws_require_user(
only_owner=False, only_system_user=False, allow_system_user=True,
only_active_user=True, only_inactive_user=False):
"""Decorate function validating login user exist in current WS connection.
Will write out error message if not authenticated.
"""
def validator(func):
"""Decorate func."""
@wraps(func)
def check_current_user(hass, connection, msg):
"""Check current user."""
def output_error(message_id, message):
"""Output error message."""
connection.send_message_outside(messages.error_message(
msg['id'], message_id, message))
if connection.user is None:
output_error('no_user', 'Not authenticated as a user')
return
if only_owner and not connection.user.is_owner:
output_error('only_owner', 'Only allowed as owner')
return
if (only_system_user and
not connection.user.system_generated):
output_error('only_system_user',
'Only allowed as system user')
return
if (not allow_system_user
and connection.user.system_generated):
output_error('not_system_user', 'Not allowed as system user')
return
if (only_active_user and
not connection.user.is_active):
output_error('only_active_user',
'Only allowed as active user')
return
if only_inactive_user and connection.user.is_active:
output_error('only_inactive_user',
'Not allowed as active user')
return
return func(hass, connection, msg)
return check_current_user
return validator

View File

@ -0,0 +1,42 @@
"""Message templates for websocket commands."""
import voluptuous as vol
from homeassistant.helpers import config_validation as cv
from . import const
# Minimal requirements of a message
MINIMAL_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
vol.Required('type'): cv.string,
}, extra=vol.ALLOW_EXTRA)
# Base schema to extend by message handlers
BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({
vol.Required('id'): cv.positive_int,
})
def result_message(iden, result=None):
"""Return a success result message."""
return {
'id': iden,
'type': const.TYPE_RESULT,
'success': True,
'result': result,
}
def error_message(iden, code, message):
"""Return an error result message."""
return {
'id': iden,
'type': const.TYPE_RESULT,
'success': False,
'error': {
'code': code,
'message': message,
},
}

View File

@ -7,7 +7,8 @@ import pytest
from homeassistant.setup import setup_component, async_setup_component
from homeassistant.const import ATTR_ENTITY_PICTURE
from homeassistant.components import camera, http, websocket_api
from homeassistant.components import camera, http
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.exceptions import HomeAssistantError
from homeassistant.util.async_ import run_coroutine_threadsafe
@ -150,7 +151,7 @@ async def test_webocket_camera_thumbnail(hass, hass_ws_client, mock_camera):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result']['content_type'] == 'image/jpeg'
assert msg['result']['content'] == \

View File

@ -9,7 +9,7 @@ from homeassistant.setup import async_setup_component
from homeassistant.components.frontend import (
DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL,
CONF_EXTRA_HTML_URL_ES5)
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api.const import TYPE_RESULT
from tests.common import mock_coro
@ -213,7 +213,7 @@ async def test_missing_themes(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result']['default_theme'] == 'default'
assert msg['result']['themes'] == {}
@ -252,7 +252,7 @@ async def test_get_panels(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result']['map']['component_name'] == 'map'
assert msg['result']['map']['url_path'] == 'map'
@ -275,7 +275,7 @@ async def test_get_translations(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result'] == {'resources': {'lang': 'nl'}}

View File

@ -3,7 +3,7 @@ from unittest.mock import patch
from homeassistant.exceptions import HomeAssistantError
from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api.const import TYPE_RESULT
async def test_deprecated_lovelace_ui(hass, hass_ws_client):
@ -20,7 +20,7 @@ async def test_deprecated_lovelace_ui(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result'] == {'hello': 'world'}
@ -39,7 +39,7 @@ async def test_deprecated_lovelace_ui_not_found(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success'] is False
assert msg['error']['code'] == 'file_not_found'
@ -58,7 +58,7 @@ async def test_deprecated_lovelace_ui_load_err(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success'] is False
assert msg['error']['code'] == 'load_error'
@ -77,7 +77,7 @@ async def test_lovelace_ui(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result'] == {'hello': 'world'}
@ -96,7 +96,7 @@ async def test_lovelace_ui_not_found(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success'] is False
assert msg['error']['code'] == 'file_not_found'
@ -115,6 +115,6 @@ async def test_lovelace_ui_load_err(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success'] is False
assert msg['error']['code'] == 'load_error'

View File

@ -3,7 +3,7 @@ import base64
from unittest.mock import patch
from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.const import TYPE_RESULT
from tests.common import mock_coro
@ -30,7 +30,7 @@ async def test_get_panels(hass, hass_ws_client):
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
assert msg['result']['content_type'] == 'image/jpeg'
assert msg['result']['content'] == \

View File

@ -1,5 +1,5 @@
"""The tests for the persistent notification component."""
from homeassistant.components import websocket_api
from homeassistant.components.websocket_api.const import TYPE_RESULT
from homeassistant.setup import setup_component, async_setup_component
import homeassistant.components.persistent_notification as pn
@ -151,7 +151,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
})
msg = await client.receive_json()
assert msg['id'] == 5
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
notifications = msg['result']
assert len(notifications) == 0
@ -165,7 +165,7 @@ async def test_ws_get_notifications(hass, hass_ws_client):
})
msg = await client.receive_json()
assert msg['id'] == 6
assert msg['type'] == websocket_api.TYPE_RESULT
assert msg['type'] == TYPE_RESULT
assert msg['success']
notifications = msg['result']
assert len(notifications) == 1

View File

@ -1,558 +0,0 @@
"""Tests for the Home Assistant Websocket API."""
import asyncio
from unittest.mock import patch, Mock
from aiohttp import WSMsgType
from async_timeout import timeout
import pytest
from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi
from homeassistant.setup import async_setup_component
from tests.common import mock_coro, async_mock_service
API_PASSWORD = 'test1234'
@pytest.fixture
def websocket_client(hass, hass_ws_client):
"""Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass))
@pytest.fixture
def no_auth_websocket_client(hass, loop, aiohttp_client):
"""Websocket connection that requires authentication."""
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(aiohttp_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
yield ws
if not ws.closed:
loop.run_until_complete(ws.close())
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
yield
@asyncio.coroutine
def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating."""
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
msg = yield from no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_OK
@asyncio.coroutine
def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
"""Test authenticating."""
with patch('homeassistant.components.websocket_api.process_wrong_login',
return_value=mock_coro()) as mock_process_wrong_login:
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD + 'wrong'
})
msg = yield from no_auth_websocket_client.receive_json()
assert mock_process_wrong_login.called
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'] == 'Invalid access token or password'
@asyncio.coroutine
def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
"""Verify that before authentication, only auth messages are allowed."""
yield from no_auth_websocket_client.send_json({
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = yield from no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_message_format(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_json({'type': 5})
msg = yield from websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_RESULT
error = msg['error']
assert error['code'] == wapi.ERR_INVALID_FORMAT
assert error['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_json(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_str('this is not JSON')
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_quiting_hass(hass, websocket_client):
"""Test sending invalid JSON."""
with patch.object(hass.loop, 'stop'):
yield from hass.async_stop()
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
@asyncio.coroutine
def test_call_service(hass, websocket_client):
"""Test call service command."""
calls = []
@callback
def service_call(call):
calls.append(call)
hass.services.async_register('domain_test', 'test_service', service_call)
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
@asyncio.coroutine
def test_subscribe_unsubscribe_events(hass, websocket_client):
"""Test subscribe/unsubscribe events command."""
init_count = sum(hass.bus.async_listeners().values())
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_SUBSCRIBE_EVENTS,
'event_type': 'test_event'
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
# Verify we have a new listener
assert sum(hass.bus.async_listeners().values()) == init_count + 1
hass.bus.async_fire('ignore_event')
hass.bus.async_fire('test_event', {'hello': 'world'})
hass.bus.async_fire('ignore_event')
with timeout(3, loop=hass.loop):
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_EVENT
event = msg['event']
assert event['event_type'] == 'test_event'
assert event['data'] == {'hello': 'world'}
assert event['origin'] == 'LOCAL'
yield from websocket_client.send_json({
'id': 6,
'type': wapi.TYPE_UNSUBSCRIBE_EVENTS,
'subscription': 5
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 6
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
# Check our listener got unsubscribed
assert sum(hass.bus.async_listeners().values()) == init_count
@asyncio.coroutine
def test_get_states(hass, websocket_client):
"""Test get_states command."""
hass.states.async_set('greeting.hello', 'world')
hass.states.async_set('greeting.bye', 'universe')
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_STATES,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
states = []
for state in hass.states.async_all():
state = state.as_dict()
state['last_changed'] = state['last_changed'].isoformat()
state['last_updated'] = state['last_updated'].isoformat()
states.append(state)
assert msg['result'] == states
@asyncio.coroutine
def test_get_services(hass, websocket_client):
"""Test get_services command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_SERVICES,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
assert msg['result'] == hass.services.async_services()
@asyncio.coroutine
def test_get_config(hass, websocket_client):
"""Test get_config command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_GET_CONFIG,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert msg['success']
if 'components' in msg['result']:
msg['result']['components'] = set(msg['result']['components'])
if 'whitelist_external_dirs' in msg['result']:
msg['result']['whitelist_external_dirs'] = \
set(msg['result']['whitelist_external_dirs'])
assert msg['result'] == hass.config.as_dict()
@asyncio.coroutine
def test_ping(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': wapi.TYPE_PING,
})
msg = yield from websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_PONG
@asyncio.coroutine
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
yield from websocket_client.send_json({
'id': idx + 1,
'type': wapi.TYPE_PING,
})
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_unknown_command(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': 'unknown_command',
})
msg = yield from websocket_client.receive_json()
assert not msg['success']
assert msg['error']['code'] == wapi.ERR_UNKNOWN_COMMAND
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_active_user_inactive(hass, aiohttp_client,
hass_access_token):
"""Test authenticating with a token."""
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True),\
patch('homeassistant.auth.AuthManager.support_legacy',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_with_invalid_token(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': 'incorrect'
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_call_service_context_with_user(hass, aiohttp_client,
hass_access_token):
"""Test that the user is set in the service call context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id == refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client):
"""Test that connection without user sets context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': wapi.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None
async def test_handler_failing(hass, websocket_client):
"""Test a command that raises."""
hass.components.websocket_api.async_register_command(
'bla', Mock(side_effect=TypeError),
wapi.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
await websocket_client.send_json({
'id': 5,
'type': 'bla',
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == wapi.TYPE_RESULT
assert not msg['success']
assert msg['error']['code'] == wapi.ERR_UNKNOWN_ERROR

View File

@ -0,0 +1,2 @@
"""Tests for the websocket API."""
API_PASSWORD = 'test1234'

View File

@ -0,0 +1,35 @@
"""Fixtures for websocket tests."""
import pytest
from homeassistant.setup import async_setup_component
from homeassistant.components import websocket_api as wapi
from . import API_PASSWORD
@pytest.fixture
def websocket_client(hass, hass_ws_client):
"""Create a websocket client."""
return hass.loop.run_until_complete(hass_ws_client(hass))
@pytest.fixture
def no_auth_websocket_client(hass, loop, aiohttp_client):
"""Websocket connection that requires authentication."""
assert loop.run_until_complete(
async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
}))
client = loop.run_until_complete(aiohttp_client(hass.http.app))
ws = loop.run_until_complete(client.ws_connect(wapi.URL))
auth_ok = loop.run_until_complete(ws.receive_json())
assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED
yield ws
if not ws.closed:
loop.run_until_complete(ws.close())

View File

@ -0,0 +1,186 @@
"""Test auth of websocket API."""
from unittest.mock import patch
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import commands
from homeassistant.setup import async_setup_component
from tests.common import mock_coro
from . import API_PASSWORD
async def test_auth_via_msg(no_auth_websocket_client):
"""Test authenticating."""
await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client):
"""Test authenticating."""
with patch('homeassistant.components.websocket_api.process_wrong_login',
return_value=mock_coro()) as mock_process_wrong_login:
await no_auth_websocket_client.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD + 'wrong'
})
msg = await no_auth_websocket_client.receive_json()
assert mock_process_wrong_login.called
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'] == 'Invalid access token or password'
async def test_pre_auth_only_auth_allowed(no_auth_websocket_client):
"""Verify that before authentication, only auth messages are allowed."""
await no_auth_websocket_client.send_json({
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await no_auth_websocket_client.receive_json()
assert msg['type'] == wapi.TYPE_AUTH_INVALID
assert msg['message'].startswith('Message incorrectly formatted')
async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_active_user_inactive(hass, aiohttp_client,
hass_access_token):
"""Test authenticating with a token."""
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
refresh_token.user.is_active = False
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_active_with_password_not_allow(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID
async def test_auth_legacy_support_with_password(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active',
return_value=True),\
patch('homeassistant.auth.AuthManager.support_legacy',
return_value=True):
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
async def test_auth_with_invalid_token(hass, aiohttp_client):
"""Test authenticating with a token."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': 'incorrect'
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID

View File

@ -0,0 +1,260 @@
"""Tests for WebSocket API commands."""
from unittest.mock import patch
from async_timeout import timeout
from homeassistant.core import callback
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import const, commands
from homeassistant.setup import async_setup_component
from tests.common import async_mock_service
from . import API_PASSWORD
async def test_call_service(hass, websocket_client):
"""Test call service command."""
calls = []
@callback
def service_call(call):
calls.append(call)
hass.services.async_register('domain_test', 'test_service', service_call)
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
async def test_subscribe_unsubscribe_events(hass, websocket_client):
"""Test subscribe/unsubscribe events command."""
init_count = sum(hass.bus.async_listeners().values())
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_SUBSCRIBE_EVENTS,
'event_type': 'test_event'
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
# Verify we have a new listener
assert sum(hass.bus.async_listeners().values()) == init_count + 1
hass.bus.async_fire('ignore_event')
hass.bus.async_fire('test_event', {'hello': 'world'})
hass.bus.async_fire('ignore_event')
with timeout(3, loop=hass.loop):
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == commands.TYPE_EVENT
event = msg['event']
assert event['event_type'] == 'test_event'
assert event['data'] == {'hello': 'world'}
assert event['origin'] == 'LOCAL'
await websocket_client.send_json({
'id': 6,
'type': commands.TYPE_UNSUBSCRIBE_EVENTS,
'subscription': 5
})
msg = await websocket_client.receive_json()
assert msg['id'] == 6
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
# Check our listener got unsubscribed
assert sum(hass.bus.async_listeners().values()) == init_count
async def test_get_states(hass, websocket_client):
"""Test get_states command."""
hass.states.async_set('greeting.hello', 'world')
hass.states.async_set('greeting.bye', 'universe')
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_STATES,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
states = []
for state in hass.states.async_all():
state = state.as_dict()
state['last_changed'] = state['last_changed'].isoformat()
state['last_updated'] = state['last_updated'].isoformat()
states.append(state)
assert msg['result'] == states
async def test_get_services(hass, websocket_client):
"""Test get_services command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_SERVICES,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
assert msg['result'] == hass.services.async_services()
async def test_get_config(hass, websocket_client):
"""Test get_config command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_GET_CONFIG,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert msg['success']
if 'components' in msg['result']:
msg['result']['components'] = set(msg['result']['components'])
if 'whitelist_external_dirs' in msg['result']:
msg['result']['whitelist_external_dirs'] = \
set(msg['result']['whitelist_external_dirs'])
assert msg['result'] == hass.config.as_dict()
async def test_ping(websocket_client):
"""Test get_panels command."""
await websocket_client.send_json({
'id': 5,
'type': commands.TYPE_PING,
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == commands.TYPE_PONG
async def test_call_service_context_with_user(hass, aiohttp_client,
hass_access_token):
"""Test that the user is set in the service call context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
with patch('homeassistant.auth.AuthManager.active') as auth_active:
auth_active.return_value = True
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'access_token': hass_access_token
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
refresh_token = await hass.auth.async_validate_access_token(
hass_access_token)
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id == refresh_token.user.id
async def test_call_service_context_no_user(hass, aiohttp_client):
"""Test that connection without user sets context."""
assert await async_setup_component(hass, 'websocket_api', {
'http': {
'api_password': API_PASSWORD
}
})
calls = async_mock_service(hass, 'domain_test', 'test_service')
client = await aiohttp_client(hass.http.app)
async with client.ws_connect(wapi.URL) as ws:
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED
await ws.send_json({
'type': wapi.TYPE_AUTH,
'api_password': API_PASSWORD
})
auth_msg = await ws.receive_json()
assert auth_msg['type'] == wapi.TYPE_AUTH_OK
await ws.send_json({
'id': 5,
'type': commands.TYPE_CALL_SERVICE,
'domain': 'domain_test',
'service': 'test_service',
'service_data': {
'hello': 'world'
}
})
msg = await ws.receive_json()
assert msg['success']
assert len(calls) == 1
call = calls[0]
assert call.domain == 'domain_test'
assert call.service == 'test_service'
assert call.data == {'hello': 'world'}
assert call.context.user_id is None

View File

@ -0,0 +1,92 @@
"""Tests for the Home Assistant Websocket API."""
import asyncio
from unittest.mock import patch, Mock
from aiohttp import WSMsgType
import pytest
from homeassistant.components import websocket_api as wapi
from homeassistant.components.websocket_api import const, commands, messages
@pytest.fixture
def mock_low_queue():
"""Mock a low queue."""
with patch.object(wapi, 'MAX_PENDING_MSG', 5):
yield
@asyncio.coroutine
def test_invalid_message_format(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_json({'type': 5})
msg = yield from websocket_client.receive_json()
assert msg['type'] == const.TYPE_RESULT
error = msg['error']
assert error['code'] == const.ERR_INVALID_FORMAT
assert error['message'].startswith('Message incorrectly formatted')
@asyncio.coroutine
def test_invalid_json(websocket_client):
"""Test sending invalid JSON."""
yield from websocket_client.send_str('this is not JSON')
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_quiting_hass(hass, websocket_client):
"""Test sending invalid JSON."""
with patch.object(hass.loop, 'stop'):
yield from hass.async_stop()
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.CLOSE
@asyncio.coroutine
def test_pending_msg_overflow(hass, mock_low_queue, websocket_client):
"""Test get_panels command."""
for idx in range(10):
yield from websocket_client.send_json({
'id': idx + 1,
'type': commands.TYPE_PING,
})
msg = yield from websocket_client.receive()
assert msg.type == WSMsgType.close
@asyncio.coroutine
def test_unknown_command(websocket_client):
"""Test get_panels command."""
yield from websocket_client.send_json({
'id': 5,
'type': 'unknown_command',
})
msg = yield from websocket_client.receive_json()
assert not msg['success']
assert msg['error']['code'] == const.ERR_UNKNOWN_COMMAND
async def test_handler_failing(hass, websocket_client):
"""Test a command that raises."""
hass.components.websocket_api.async_register_command(
'bla', Mock(side_effect=TypeError),
messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'}))
await websocket_client.send_json({
'id': 5,
'type': 'bla',
})
msg = await websocket_client.receive_json()
assert msg['id'] == 5
assert msg['type'] == const.TYPE_RESULT
assert not msg['success']
assert msg['error']['code'] == const.ERR_UNKNOWN_ERROR