diff --git a/homeassistant/components/camera/__init__.py b/homeassistant/components/camera/__init__.py index b32236e499d..95f0cddf320 100644 --- a/homeassistant/components/camera/__init__.py +++ b/homeassistant/components/camera/__init__.py @@ -452,27 +452,23 @@ class CameraMjpegStream(CameraView): raise web.HTTPBadRequest() -@callback -def websocket_camera_thumbnail(hass, connection, msg): +@websocket_api.async_response +async def websocket_camera_thumbnail(hass, connection, msg): """Handle get camera thumbnail websocket command. Async friendly. """ - async def send_camera_still(): - """Send a camera still.""" - try: - image = await async_get_image(hass, msg['entity_id']) - connection.send_message_outside(websocket_api.result_message( - msg['id'], { - 'content_type': image.content_type, - 'content': base64.b64encode(image.content).decode('utf-8') - } - )) - except HomeAssistantError: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'image_fetch_failed', 'Unable to fetch image')) - - hass.async_add_job(send_camera_still()) + try: + image = await async_get_image(hass, msg['entity_id']) + connection.send_message_outside(websocket_api.result_message( + msg['id'], { + 'content_type': image.content_type, + 'content': base64.b64encode(image.content).decode('utf-8') + } + )) + except HomeAssistantError: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'image_fetch_failed', 'Unable to fetch image')) async def async_handle_snapshot_service(camera, service): diff --git a/homeassistant/components/config/auth.py b/homeassistant/components/config/auth.py index 6f00b03dedb..17dd132d4b4 100644 --- a/homeassistant/components/config/auth.py +++ b/homeassistant/components/config/auth.py @@ -3,6 +3,7 @@ import voluptuous as vol from homeassistant.core import callback from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.decorators import require_owner WS_TYPE_LIST = 'config/auth/list' @@ -41,7 +42,7 @@ async def async_setup(hass): @callback -@websocket_api.require_owner +@require_owner def websocket_list(hass, connection, msg): """Return a list of users.""" async def send_users(): @@ -55,7 +56,7 @@ def websocket_list(hass, connection, msg): @callback -@websocket_api.require_owner +@require_owner def websocket_delete(hass, connection, msg): """Delete a user.""" async def delete_user(): @@ -82,7 +83,7 @@ def websocket_delete(hass, connection, msg): @callback -@websocket_api.require_owner +@require_owner def websocket_create(hass, connection, msg): """Create a user.""" async def create_user(): diff --git a/homeassistant/components/config/auth_provider_homeassistant.py b/homeassistant/components/config/auth_provider_homeassistant.py index 960e8f5e7b4..8f0c969a808 100644 --- a/homeassistant/components/config/auth_provider_homeassistant.py +++ b/homeassistant/components/config/auth_provider_homeassistant.py @@ -4,6 +4,7 @@ import voluptuous as vol from homeassistant.auth.providers import homeassistant as auth_ha from homeassistant.core import callback from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.decorators import require_owner WS_TYPE_CREATE = 'config/auth_provider/homeassistant/create' @@ -55,7 +56,7 @@ def _get_provider(hass): @callback -@websocket_api.require_owner +@require_owner def websocket_create(hass, connection, msg): """Create credentials and attach to a user.""" async def create_creds(): @@ -96,7 +97,7 @@ def websocket_create(hass, connection, msg): @callback -@websocket_api.require_owner +@require_owner def websocket_delete(hass, connection, msg): """Delete username and related credential.""" async def delete_creds(): diff --git a/homeassistant/components/config/entity_registry.py b/homeassistant/components/config/entity_registry.py index 0f9abf167e5..18d66ec623a 100644 --- a/homeassistant/components/config/entity_registry.py +++ b/homeassistant/components/config/entity_registry.py @@ -4,6 +4,8 @@ import voluptuous as vol from homeassistant.core import callback from homeassistant.helpers.entity_registry import async_get_registry from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.const import ERR_NOT_FOUND +from homeassistant.components.websocket_api.decorators import async_response from homeassistant.helpers import config_validation as cv DEPENDENCIES = ['websocket_api'] @@ -46,89 +48,77 @@ async def async_setup(hass): return True -@callback -def websocket_list_entities(hass, connection, msg): +@async_response +async def websocket_list_entities(hass, connection, msg): """Handle list registry entries command. Async friendly. """ - async def retrieve_entities(): - """Get entities from registry.""" - registry = await async_get_registry(hass) - connection.send_message_outside(websocket_api.result_message( - msg['id'], [{ - 'config_entry_id': entry.config_entry_id, - 'device_id': entry.device_id, - 'disabled_by': entry.disabled_by, - 'entity_id': entry.entity_id, - 'name': entry.name, - 'platform': entry.platform, - } for entry in registry.entities.values()] - )) - - hass.async_add_job(retrieve_entities()) + registry = await async_get_registry(hass) + connection.send_message_outside(websocket_api.result_message( + msg['id'], [{ + 'config_entry_id': entry.config_entry_id, + 'device_id': entry.device_id, + 'disabled_by': entry.disabled_by, + 'entity_id': entry.entity_id, + 'name': entry.name, + 'platform': entry.platform, + } for entry in registry.entities.values()] + )) -@callback -def websocket_get_entity(hass, connection, msg): +@async_response +async def websocket_get_entity(hass, connection, msg): """Handle get entity registry entry command. Async friendly. """ - async def retrieve_entity(): - """Get entity from registry.""" - registry = await async_get_registry(hass) - entry = registry.entities.get(msg['entity_id']) + registry = await async_get_registry(hass) + entry = registry.entities.get(msg['entity_id']) - if entry is None: - connection.send_message_outside(websocket_api.error_message( - msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) - return + if entry is None: + connection.send_message_outside(websocket_api.error_message( + msg['id'], ERR_NOT_FOUND, 'Entity not found')) + return - connection.send_message_outside(websocket_api.result_message( - msg['id'], _entry_dict(entry) - )) - - hass.async_add_job(retrieve_entity()) + connection.send_message_outside(websocket_api.result_message( + msg['id'], _entry_dict(entry) + )) -@callback -def websocket_update_entity(hass, connection, msg): +@async_response +async def websocket_update_entity(hass, connection, msg): """Handle get camera thumbnail websocket command. Async friendly. """ - async def update_entity(): - """Get entity from registry.""" - registry = await async_get_registry(hass) + registry = await async_get_registry(hass) - if msg['entity_id'] not in registry.entities: - connection.send_message_outside(websocket_api.error_message( - msg['id'], websocket_api.ERR_NOT_FOUND, 'Entity not found')) - return + if msg['entity_id'] not in registry.entities: + connection.send_message_outside(websocket_api.error_message( + msg['id'], ERR_NOT_FOUND, 'Entity not found')) + return - changes = {} + changes = {} - if 'name' in msg: - changes['name'] = msg['name'] + if 'name' in msg: + changes['name'] = msg['name'] - if 'new_entity_id' in msg: - changes['new_entity_id'] = msg['new_entity_id'] + if 'new_entity_id' in msg: + changes['new_entity_id'] = msg['new_entity_id'] - try: - if changes: - entry = registry.async_update_entity( - msg['entity_id'], **changes) - except ValueError as err: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'invalid_info', str(err) - )) - else: - connection.send_message_outside(websocket_api.result_message( - msg['id'], _entry_dict(entry) - )) - - hass.async_create_task(update_entity()) + try: + if changes: + entry = registry.async_update_entity( + msg['entity_id'], **changes) + except ValueError as err: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'invalid_info', str(err) + )) + else: + connection.send_message_outside(websocket_api.result_message( + msg['id'], _entry_dict(entry) + )) @callback diff --git a/homeassistant/components/media_player/__init__.py b/homeassistant/components/media_player/__init__.py index 235ca8d5b2d..85016df7262 100644 --- a/homeassistant/components/media_player/__init__.py +++ b/homeassistant/components/media_player/__init__.py @@ -28,7 +28,6 @@ from homeassistant.const import ( SERVICE_TOGGLE, SERVICE_TURN_OFF, SERVICE_TURN_ON, SERVICE_VOLUME_DOWN, SERVICE_VOLUME_MUTE, SERVICE_VOLUME_SET, SERVICE_VOLUME_UP, STATE_IDLE, STATE_OFF, STATE_PLAYING, STATE_UNKNOWN) -from homeassistant.core import callback from homeassistant.helpers.aiohttp_client import async_get_clientsession import homeassistant.helpers.config_validation as cv from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa @@ -865,8 +864,8 @@ class MediaPlayerImageView(HomeAssistantView): body=data, content_type=content_type, headers=headers) -@callback -def websocket_handle_thumbnail(hass, connection, msg): +@websocket_api.async_response +async def websocket_handle_thumbnail(hass, connection, msg): """Handle get media player cover command. Async friendly. @@ -879,20 +878,16 @@ def websocket_handle_thumbnail(hass, connection, msg): msg['id'], 'entity_not_found', 'Entity not found')) return - async def send_image(): - """Send image.""" - data, content_type = await player.async_get_media_image() + data, content_type = await player.async_get_media_image() - if data is None: - connection.send_message_outside(websocket_api.error_message( - msg['id'], 'thumbnail_fetch_failed', - 'Failed to fetch thumbnail')) - return + if data is None: + connection.send_message_outside(websocket_api.error_message( + msg['id'], 'thumbnail_fetch_failed', + 'Failed to fetch thumbnail')) + return - connection.send_message_outside(websocket_api.result_message( - msg['id'], { - 'content_type': content_type, - 'content': base64.b64encode(data).decode('utf-8') - })) - - hass.async_add_job(send_image()) + connection.send_message_outside(websocket_api.result_message( + msg['id'], { + 'content_type': content_type, + 'content': base64.b64encode(data).decode('utf-8') + })) diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api/__init__.py similarity index 52% rename from homeassistant/components/websocket_api.py rename to homeassistant/components/websocket_api/__init__.py index 4e7c186facc..448256e31fd 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api/__init__.py @@ -7,7 +7,7 @@ https://developers.home-assistant.io/docs/external_api_websocket.html import asyncio from concurrent import futures from contextlib import suppress -from functools import partial, wraps +from functools import partial import json import logging @@ -15,20 +15,18 @@ from aiohttp import web import voluptuous as vol from voluptuous.humanize import humanize_error -from homeassistant.const import ( - MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, - __version__) -from homeassistant.core import Context, callback, HomeAssistant +from homeassistant.const import EVENT_HOMEASSISTANT_STOP, __version__ +from homeassistant.core import Context, callback from homeassistant.loader import bind_hass from homeassistant.helpers.json import JSONEncoder -from homeassistant.helpers import config_validation as cv -from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.components.http import HomeAssistantView from homeassistant.components.http.auth import validate_password from homeassistant.components.http.const import KEY_AUTHENTICATED from homeassistant.components.http.ban import process_wrong_login, \ process_success_login +from . import commands, const, decorators, messages + DOMAIN = 'websocket_api' URL = '/api/websocket' @@ -36,87 +34,32 @@ DEPENDENCIES = ('http',) MAX_PENDING_MSG = 512 -ERR_ID_REUSE = 1 -ERR_INVALID_FORMAT = 2 -ERR_NOT_FOUND = 3 -ERR_UNKNOWN_COMMAND = 4 -ERR_UNKNOWN_ERROR = 5 - -TYPE_AUTH = 'auth' -TYPE_AUTH_INVALID = 'auth_invalid' -TYPE_AUTH_OK = 'auth_ok' -TYPE_AUTH_REQUIRED = 'auth_required' -TYPE_CALL_SERVICE = 'call_service' -TYPE_EVENT = 'event' -TYPE_GET_CONFIG = 'get_config' -TYPE_GET_SERVICES = 'get_services' -TYPE_GET_STATES = 'get_states' -TYPE_PING = 'ping' -TYPE_PONG = 'pong' -TYPE_RESULT = 'result' -TYPE_SUBSCRIBE_EVENTS = 'subscribe_events' -TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events' _LOGGER = logging.getLogger(__name__) JSON_DUMP = partial(json.dumps, cls=JSONEncoder) +TYPE_AUTH = 'auth' +TYPE_AUTH_INVALID = 'auth_invalid' +TYPE_AUTH_OK = 'auth_ok' +TYPE_AUTH_REQUIRED = 'auth_required' + + +# Backwards compat +# pylint: disable=invalid-name +BASE_COMMAND_MESSAGE_SCHEMA = messages.BASE_COMMAND_MESSAGE_SCHEMA +error_message = messages.error_message +result_message = messages.result_message +async_response = decorators.async_response +ws_require_user = decorators.ws_require_user +# pylint: enable=invalid-name + AUTH_MESSAGE_SCHEMA = vol.Schema({ vol.Required('type'): TYPE_AUTH, vol.Exclusive('api_password', 'auth'): str, vol.Exclusive('access_token', 'auth'): str, }) -# Minimal requirements of a message -MINIMAL_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, - vol.Required('type'): cv.string, -}, extra=vol.ALLOW_EXTRA) -# Base schema to extend by message handlers -BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({ - vol.Required('id'): cv.positive_int, -}) - - -SCHEMA_SUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_SUBSCRIBE_EVENTS, - vol.Optional('event_type', default=MATCH_ALL): str, -}) - - -SCHEMA_UNSUBSCRIBE_EVENTS = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS, - vol.Required('subscription'): cv.positive_int, -}) - - -SCHEMA_CALL_SERVICE = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_CALL_SERVICE, - vol.Required('domain'): str, - vol.Required('service'): str, - vol.Optional('service_data'): dict -}) - - -SCHEMA_GET_STATES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_GET_STATES, -}) - - -SCHEMA_GET_SERVICES = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_GET_SERVICES, -}) - - -SCHEMA_GET_CONFIG = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_GET_CONFIG, -}) - - -SCHEMA_PING = BASE_COMMAND_MESSAGE_SCHEMA.extend({ - vol.Required('type'): TYPE_PING, -}) - # Define the possible errors that occur when connections are cancelled. # Originally, this was just asyncio.CancelledError, but issue #9546 showed @@ -148,46 +91,6 @@ def auth_invalid_message(message): } -def event_message(iden, event): - """Return an event message.""" - return { - 'id': iden, - 'type': TYPE_EVENT, - 'event': event.as_dict(), - } - - -def error_message(iden, code, message): - """Return an error result message.""" - return { - 'id': iden, - 'type': TYPE_RESULT, - 'success': False, - 'error': { - 'code': code, - 'message': message, - }, - } - - -def pong_message(iden): - """Return a pong message.""" - return { - 'id': iden, - 'type': TYPE_PONG, - } - - -def result_message(iden, result=None): - """Return a success result message.""" - return { - 'id': iden, - 'type': TYPE_RESULT, - 'success': True, - 'result': result, - } - - @bind_hass @callback def async_register_command(hass, command, handler, schema): @@ -198,43 +101,10 @@ def async_register_command(hass, command, handler, schema): handlers[command] = (handler, schema) -def require_owner(func): - """Websocket decorator to require user to be an owner.""" - @wraps(func) - def with_owner(hass, connection, msg): - """Check owner and call function.""" - user = connection.request.get('hass_user') - - if user is None or not user.is_owner: - connection.to_write.put_nowait(error_message( - msg['id'], 'unauthorized', 'This command is for owners only.')) - return - - func(hass, connection, msg) - - return with_owner - - async def async_setup(hass, config): """Initialize the websocket API.""" hass.http.register_view(WebsocketAPIView) - - async_register_command(hass, TYPE_SUBSCRIBE_EVENTS, - handle_subscribe_events, SCHEMA_SUBSCRIBE_EVENTS) - async_register_command(hass, TYPE_UNSUBSCRIBE_EVENTS, - handle_unsubscribe_events, - SCHEMA_UNSUBSCRIBE_EVENTS) - async_register_command(hass, TYPE_CALL_SERVICE, - handle_call_service, SCHEMA_CALL_SERVICE) - async_register_command(hass, TYPE_GET_STATES, - handle_get_states, SCHEMA_GET_STATES) - async_register_command(hass, TYPE_GET_SERVICES, - handle_get_services, SCHEMA_GET_SERVICES) - async_register_command(hass, TYPE_GET_CONFIG, - handle_get_config, SCHEMA_GET_CONFIG) - async_register_command(hass, TYPE_PING, - handle_ping, SCHEMA_PING) - + commands.async_register_commands(hass) return True @@ -389,19 +259,19 @@ class ActiveConnection: while msg: self.debug("Received", msg) - msg = MINIMAL_MESSAGE_SCHEMA(msg) + msg = messages.MINIMAL_MESSAGE_SCHEMA(msg) cur_id = msg['id'] if cur_id <= last_id: - self.to_write.put_nowait(error_message( - cur_id, ERR_ID_REUSE, + self.to_write.put_nowait(messages.error_message( + cur_id, const.ERR_ID_REUSE, 'Identifier values have to increase.')) elif msg['type'] not in handlers: self.log_error( 'Received invalid command: {}'.format(msg['type'])) - self.to_write.put_nowait(error_message( - cur_id, ERR_UNKNOWN_COMMAND, + self.to_write.put_nowait(messages.error_message( + cur_id, const.ERR_UNKNOWN_COMMAND, 'Unknown command.')) else: @@ -410,8 +280,8 @@ class ActiveConnection: handler(self.hass, self, schema(msg)) except Exception: # pylint: disable=broad-except _LOGGER.exception('Error handling message: %s', msg) - self.to_write.put_nowait(error_message( - cur_id, ERR_UNKNOWN_ERROR, + self.to_write.put_nowait(messages.error_message( + cur_id, const.ERR_UNKNOWN_ERROR, 'Unknown error.')) last_id = cur_id @@ -435,8 +305,8 @@ class ActiveConnection: else: iden = None - final_message = error_message( - iden, ERR_INVALID_FORMAT, error_msg) + final_message = messages.error_message( + iden, const.ERR_INVALID_FORMAT, error_msg) except TypeError as err: if wsock.closed: @@ -485,170 +355,3 @@ class ActiveConnection: self.debug("Closed connection") return wsock - - -def async_response(func): - """Decorate an async function to handle WebSocket API messages.""" - async def handle_msg_response(hass, connection, msg): - """Create a response and handle exception.""" - try: - await func(hass, connection, msg) - except Exception: # pylint: disable=broad-except - _LOGGER.exception("Unexpected exception") - connection.send_message_outside(error_message( - msg['id'], 'unknown', 'Unexpected error occurred')) - - @callback - @wraps(func) - def schedule_handler(hass, connection, msg): - """Schedule the handler.""" - hass.async_create_task(handle_msg_response(hass, connection, msg)) - - return schedule_handler - - -@callback -def handle_subscribe_events(hass, connection, msg): - """Handle subscribe events command. - - Async friendly. - """ - async def forward_events(event): - """Forward events to websocket.""" - if event.event_type == EVENT_TIME_CHANGED: - return - - connection.send_message_outside(event_message(msg['id'], event)) - - connection.event_listeners[msg['id']] = hass.bus.async_listen( - msg['event_type'], forward_events) - - connection.to_write.put_nowait(result_message(msg['id'])) - - -@callback -def handle_unsubscribe_events(hass, connection, msg): - """Handle unsubscribe events command. - - Async friendly. - """ - subscription = msg['subscription'] - - if subscription in connection.event_listeners: - connection.event_listeners.pop(subscription)() - connection.to_write.put_nowait(result_message(msg['id'])) - else: - connection.to_write.put_nowait(error_message( - msg['id'], ERR_NOT_FOUND, 'Subscription not found.')) - - -@async_response -async def handle_call_service(hass, connection, msg): - """Handle call service command. - - Async friendly. - """ - blocking = True - if (msg['domain'] == 'homeassistant' and - msg['service'] in ['restart', 'stop']): - blocking = False - await hass.services.async_call( - msg['domain'], msg['service'], msg.get('service_data'), blocking, - connection.context(msg)) - connection.send_message_outside(result_message(msg['id'])) - - -@callback -def handle_get_states(hass, connection, msg): - """Handle get states command. - - Async friendly. - """ - connection.to_write.put_nowait(result_message( - msg['id'], hass.states.async_all())) - - -@async_response -async def handle_get_services(hass, connection, msg): - """Handle get services command. - - Async friendly. - """ - descriptions = await async_get_all_descriptions(hass) - connection.send_message_outside( - result_message(msg['id'], descriptions)) - - -@callback -def handle_get_config(hass, connection, msg): - """Handle get config command. - - Async friendly. - """ - connection.to_write.put_nowait(result_message( - msg['id'], hass.config.as_dict())) - - -@callback -def handle_ping(hass, connection, msg): - """Handle ping command. - - Async friendly. - """ - connection.to_write.put_nowait(pong_message(msg['id'])) - - -def ws_require_user( - only_owner=False, only_system_user=False, allow_system_user=True, - only_active_user=True, only_inactive_user=False): - """Decorate function validating login user exist in current WS connection. - - Will write out error message if not authenticated. - """ - def validator(func): - """Decorate func.""" - @wraps(func) - def check_current_user(hass: HomeAssistant, - connection: ActiveConnection, - msg): - """Check current user.""" - def output_error(message_id, message): - """Output error message.""" - connection.send_message_outside(error_message( - msg['id'], message_id, message)) - - if connection.user is None: - output_error('no_user', 'Not authenticated as a user') - return - - if only_owner and not connection.user.is_owner: - output_error('only_owner', 'Only allowed as owner') - return - - if (only_system_user and - not connection.user.system_generated): - output_error('only_system_user', - 'Only allowed as system user') - return - - if (not allow_system_user - and connection.user.system_generated): - output_error('not_system_user', 'Not allowed as system user') - return - - if (only_active_user and - not connection.user.is_active): - output_error('only_active_user', - 'Only allowed as active user') - return - - if only_inactive_user and connection.user.is_active: - output_error('only_inactive_user', - 'Not allowed as active user') - return - - return func(hass, connection, msg) - - return check_current_user - - return validator diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py new file mode 100644 index 00000000000..c9808f3a692 --- /dev/null +++ b/homeassistant/components/websocket_api/commands.py @@ -0,0 +1,183 @@ +"""Commands part of Websocket API.""" +import voluptuous as vol + +from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED +from homeassistant.core import callback +from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.service import async_get_all_descriptions + +from . import const, decorators, messages + + +TYPE_CALL_SERVICE = 'call_service' +TYPE_EVENT = 'event' +TYPE_GET_CONFIG = 'get_config' +TYPE_GET_SERVICES = 'get_services' +TYPE_GET_STATES = 'get_states' +TYPE_PING = 'ping' +TYPE_PONG = 'pong' +TYPE_SUBSCRIBE_EVENTS = 'subscribe_events' +TYPE_UNSUBSCRIBE_EVENTS = 'unsubscribe_events' + + +@callback +def async_register_commands(hass): + """Register commands.""" + async_reg = hass.components.websocket_api.async_register_command + async_reg(TYPE_SUBSCRIBE_EVENTS, handle_subscribe_events, + SCHEMA_SUBSCRIBE_EVENTS) + async_reg(TYPE_UNSUBSCRIBE_EVENTS, handle_unsubscribe_events, + SCHEMA_UNSUBSCRIBE_EVENTS) + async_reg(TYPE_CALL_SERVICE, handle_call_service, SCHEMA_CALL_SERVICE) + async_reg(TYPE_GET_STATES, handle_get_states, SCHEMA_GET_STATES) + async_reg(TYPE_GET_SERVICES, handle_get_services, SCHEMA_GET_SERVICES) + async_reg(TYPE_GET_CONFIG, handle_get_config, SCHEMA_GET_CONFIG) + async_reg(TYPE_PING, handle_ping, SCHEMA_PING) + + +SCHEMA_SUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_SUBSCRIBE_EVENTS, + vol.Optional('event_type', default=MATCH_ALL): str, +}) + + +SCHEMA_UNSUBSCRIBE_EVENTS = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_UNSUBSCRIBE_EVENTS, + vol.Required('subscription'): cv.positive_int, +}) + + +SCHEMA_CALL_SERVICE = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_CALL_SERVICE, + vol.Required('domain'): str, + vol.Required('service'): str, + vol.Optional('service_data'): dict +}) + + +SCHEMA_GET_STATES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_GET_STATES, +}) + + +SCHEMA_GET_SERVICES = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_GET_SERVICES, +}) + + +SCHEMA_GET_CONFIG = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_GET_CONFIG, +}) + + +SCHEMA_PING = messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({ + vol.Required('type'): TYPE_PING, +}) + + +def event_message(iden, event): + """Return an event message.""" + return { + 'id': iden, + 'type': TYPE_EVENT, + 'event': event.as_dict(), + } + + +def pong_message(iden): + """Return a pong message.""" + return { + 'id': iden, + 'type': TYPE_PONG, + } + + +@callback +def handle_subscribe_events(hass, connection, msg): + """Handle subscribe events command. + + Async friendly. + """ + async def forward_events(event): + """Forward events to websocket.""" + if event.event_type == EVENT_TIME_CHANGED: + return + + connection.send_message_outside(event_message(msg['id'], event)) + + connection.event_listeners[msg['id']] = hass.bus.async_listen( + msg['event_type'], forward_events) + + connection.to_write.put_nowait(messages.result_message(msg['id'])) + + +@callback +def handle_unsubscribe_events(hass, connection, msg): + """Handle unsubscribe events command. + + Async friendly. + """ + subscription = msg['subscription'] + + if subscription in connection.event_listeners: + connection.event_listeners.pop(subscription)() + connection.to_write.put_nowait(messages.result_message(msg['id'])) + else: + connection.to_write.put_nowait(messages.error_message( + msg['id'], const.ERR_NOT_FOUND, 'Subscription not found.')) + + +@decorators.async_response +async def handle_call_service(hass, connection, msg): + """Handle call service command. + + Async friendly. + """ + blocking = True + if (msg['domain'] == 'homeassistant' and + msg['service'] in ['restart', 'stop']): + blocking = False + await hass.services.async_call( + msg['domain'], msg['service'], msg.get('service_data'), blocking, + connection.context(msg)) + connection.send_message_outside(messages.result_message(msg['id'])) + + +@callback +def handle_get_states(hass, connection, msg): + """Handle get states command. + + Async friendly. + """ + connection.to_write.put_nowait(messages.result_message( + msg['id'], hass.states.async_all())) + + +@decorators.async_response +async def handle_get_services(hass, connection, msg): + """Handle get services command. + + Async friendly. + """ + descriptions = await async_get_all_descriptions(hass) + connection.send_message_outside( + messages.result_message(msg['id'], descriptions)) + + +@callback +def handle_get_config(hass, connection, msg): + """Handle get config command. + + Async friendly. + """ + connection.to_write.put_nowait(messages.result_message( + msg['id'], hass.config.as_dict())) + + +@callback +def handle_ping(hass, connection, msg): + """Handle ping command. + + Async friendly. + """ + connection.to_write.put_nowait(pong_message(msg['id'])) diff --git a/homeassistant/components/websocket_api/const.py b/homeassistant/components/websocket_api/const.py new file mode 100644 index 00000000000..cbc56b168c6 --- /dev/null +++ b/homeassistant/components/websocket_api/const.py @@ -0,0 +1,8 @@ +"""Websocket constants.""" +ERR_ID_REUSE = 1 +ERR_INVALID_FORMAT = 2 +ERR_NOT_FOUND = 3 +ERR_UNKNOWN_COMMAND = 4 +ERR_UNKNOWN_ERROR = 5 + +TYPE_RESULT = 'result' diff --git a/homeassistant/components/websocket_api/decorators.py b/homeassistant/components/websocket_api/decorators.py new file mode 100644 index 00000000000..df32dd06d2b --- /dev/null +++ b/homeassistant/components/websocket_api/decorators.py @@ -0,0 +1,101 @@ +"""Decorators for the Websocket API.""" +from functools import wraps +import logging + +from homeassistant.core import callback + +from . import messages + + +_LOGGER = logging.getLogger(__name__) + + +def async_response(func): + """Decorate an async function to handle WebSocket API messages.""" + async def handle_msg_response(hass, connection, msg): + """Create a response and handle exception.""" + try: + await func(hass, connection, msg) + except Exception: # pylint: disable=broad-except + _LOGGER.exception("Unexpected exception") + connection.send_message_outside(messages.error_message( + msg['id'], 'unknown', 'Unexpected error occurred')) + + @callback + @wraps(func) + def schedule_handler(hass, connection, msg): + """Schedule the handler.""" + hass.async_create_task(handle_msg_response(hass, connection, msg)) + + return schedule_handler + + +def require_owner(func): + """Websocket decorator to require user to be an owner.""" + @wraps(func) + def with_owner(hass, connection, msg): + """Check owner and call function.""" + user = connection.request.get('hass_user') + + if user is None or not user.is_owner: + connection.to_write.put_nowait(messages.error_message( + msg['id'], 'unauthorized', 'This command is for owners only.')) + return + + func(hass, connection, msg) + + return with_owner + + +def ws_require_user( + only_owner=False, only_system_user=False, allow_system_user=True, + only_active_user=True, only_inactive_user=False): + """Decorate function validating login user exist in current WS connection. + + Will write out error message if not authenticated. + """ + def validator(func): + """Decorate func.""" + @wraps(func) + def check_current_user(hass, connection, msg): + """Check current user.""" + def output_error(message_id, message): + """Output error message.""" + connection.send_message_outside(messages.error_message( + msg['id'], message_id, message)) + + if connection.user is None: + output_error('no_user', 'Not authenticated as a user') + return + + if only_owner and not connection.user.is_owner: + output_error('only_owner', 'Only allowed as owner') + return + + if (only_system_user and + not connection.user.system_generated): + output_error('only_system_user', + 'Only allowed as system user') + return + + if (not allow_system_user + and connection.user.system_generated): + output_error('not_system_user', 'Not allowed as system user') + return + + if (only_active_user and + not connection.user.is_active): + output_error('only_active_user', + 'Only allowed as active user') + return + + if only_inactive_user and connection.user.is_active: + output_error('only_inactive_user', + 'Not allowed as active user') + return + + return func(hass, connection, msg) + + return check_current_user + + return validator diff --git a/homeassistant/components/websocket_api/messages.py b/homeassistant/components/websocket_api/messages.py new file mode 100644 index 00000000000..d616b6ad670 --- /dev/null +++ b/homeassistant/components/websocket_api/messages.py @@ -0,0 +1,42 @@ +"""Message templates for websocket commands.""" + +import voluptuous as vol + +from homeassistant.helpers import config_validation as cv + +from . import const + + +# Minimal requirements of a message +MINIMAL_MESSAGE_SCHEMA = vol.Schema({ + vol.Required('id'): cv.positive_int, + vol.Required('type'): cv.string, +}, extra=vol.ALLOW_EXTRA) + +# Base schema to extend by message handlers +BASE_COMMAND_MESSAGE_SCHEMA = vol.Schema({ + vol.Required('id'): cv.positive_int, +}) + + +def result_message(iden, result=None): + """Return a success result message.""" + return { + 'id': iden, + 'type': const.TYPE_RESULT, + 'success': True, + 'result': result, + } + + +def error_message(iden, code, message): + """Return an error result message.""" + return { + 'id': iden, + 'type': const.TYPE_RESULT, + 'success': False, + 'error': { + 'code': code, + 'message': message, + }, + } diff --git a/tests/components/camera/test_init.py b/tests/components/camera/test_init.py index 2129e39a43c..6b98f378ef0 100644 --- a/tests/components/camera/test_init.py +++ b/tests/components/camera/test_init.py @@ -7,7 +7,8 @@ import pytest from homeassistant.setup import setup_component, async_setup_component from homeassistant.const import ATTR_ENTITY_PICTURE -from homeassistant.components import camera, http, websocket_api +from homeassistant.components import camera, http +from homeassistant.components.websocket_api.const import TYPE_RESULT from homeassistant.exceptions import HomeAssistantError from homeassistant.util.async_ import run_coroutine_threadsafe @@ -150,7 +151,7 @@ async def test_webocket_camera_thumbnail(hass, hass_ws_client, mock_camera): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == websocket_api.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result']['content_type'] == 'image/jpeg' assert msg['result']['content'] == \ diff --git a/tests/components/frontend/test_init.py b/tests/components/frontend/test_init.py index b29c8cfb12f..2e78e0441a3 100644 --- a/tests/components/frontend/test_init.py +++ b/tests/components/frontend/test_init.py @@ -9,7 +9,7 @@ from homeassistant.setup import async_setup_component from homeassistant.components.frontend import ( DOMAIN, CONF_JS_VERSION, CONF_THEMES, CONF_EXTRA_HTML_URL, CONF_EXTRA_HTML_URL_ES5) -from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api.const import TYPE_RESULT from tests.common import mock_coro @@ -213,7 +213,7 @@ async def test_missing_themes(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result']['default_theme'] == 'default' assert msg['result']['themes'] == {} @@ -252,7 +252,7 @@ async def test_get_panels(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result']['map']['component_name'] == 'map' assert msg['result']['map']['url_path'] == 'map' @@ -275,7 +275,7 @@ async def test_get_translations(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result'] == {'resources': {'lang': 'nl'}} diff --git a/tests/components/lovelace/test_init.py b/tests/components/lovelace/test_init.py index 3bb7c0675ea..0fde6de902c 100644 --- a/tests/components/lovelace/test_init.py +++ b/tests/components/lovelace/test_init.py @@ -3,7 +3,7 @@ from unittest.mock import patch from homeassistant.exceptions import HomeAssistantError from homeassistant.setup import async_setup_component -from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api.const import TYPE_RESULT async def test_deprecated_lovelace_ui(hass, hass_ws_client): @@ -20,7 +20,7 @@ async def test_deprecated_lovelace_ui(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result'] == {'hello': 'world'} @@ -39,7 +39,7 @@ async def test_deprecated_lovelace_ui_not_found(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] is False assert msg['error']['code'] == 'file_not_found' @@ -58,7 +58,7 @@ async def test_deprecated_lovelace_ui_load_err(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] is False assert msg['error']['code'] == 'load_error' @@ -77,7 +77,7 @@ async def test_lovelace_ui(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result'] == {'hello': 'world'} @@ -96,7 +96,7 @@ async def test_lovelace_ui_not_found(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] is False assert msg['error']['code'] == 'file_not_found' @@ -115,6 +115,6 @@ async def test_lovelace_ui_load_err(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] is False assert msg['error']['code'] == 'load_error' diff --git a/tests/components/media_player/test_init.py b/tests/components/media_player/test_init.py index 5d632d4de0b..808c6e4f50f 100644 --- a/tests/components/media_player/test_init.py +++ b/tests/components/media_player/test_init.py @@ -3,7 +3,7 @@ import base64 from unittest.mock import patch from homeassistant.setup import async_setup_component -from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.const import TYPE_RESULT from tests.common import mock_coro @@ -30,7 +30,7 @@ async def test_get_panels(hass, hass_ws_client): msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == websocket_api.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] assert msg['result']['content_type'] == 'image/jpeg' assert msg['result']['content'] == \ diff --git a/tests/components/persistent_notification/test_init.py b/tests/components/persistent_notification/test_init.py index 6acc796a108..5df106a5327 100644 --- a/tests/components/persistent_notification/test_init.py +++ b/tests/components/persistent_notification/test_init.py @@ -1,5 +1,5 @@ """The tests for the persistent notification component.""" -from homeassistant.components import websocket_api +from homeassistant.components.websocket_api.const import TYPE_RESULT from homeassistant.setup import setup_component, async_setup_component import homeassistant.components.persistent_notification as pn @@ -151,7 +151,7 @@ async def test_ws_get_notifications(hass, hass_ws_client): }) msg = await client.receive_json() assert msg['id'] == 5 - assert msg['type'] == websocket_api.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] notifications = msg['result'] assert len(notifications) == 0 @@ -165,7 +165,7 @@ async def test_ws_get_notifications(hass, hass_ws_client): }) msg = await client.receive_json() assert msg['id'] == 6 - assert msg['type'] == websocket_api.TYPE_RESULT + assert msg['type'] == TYPE_RESULT assert msg['success'] notifications = msg['result'] assert len(notifications) == 1 diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py deleted file mode 100644 index cf74081adb1..00000000000 --- a/tests/components/test_websocket_api.py +++ /dev/null @@ -1,558 +0,0 @@ -"""Tests for the Home Assistant Websocket API.""" -import asyncio -from unittest.mock import patch, Mock - -from aiohttp import WSMsgType -from async_timeout import timeout -import pytest - -from homeassistant.core import callback -from homeassistant.components import websocket_api as wapi -from homeassistant.setup import async_setup_component - -from tests.common import mock_coro, async_mock_service - -API_PASSWORD = 'test1234' - - -@pytest.fixture -def websocket_client(hass, hass_ws_client): - """Create a websocket client.""" - return hass.loop.run_until_complete(hass_ws_client(hass)) - - -@pytest.fixture -def no_auth_websocket_client(hass, loop, aiohttp_client): - """Websocket connection that requires authentication.""" - assert loop.run_until_complete( - async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - })) - - client = loop.run_until_complete(aiohttp_client(hass.http.app)) - ws = loop.run_until_complete(client.ws_connect(wapi.URL)) - - auth_ok = loop.run_until_complete(ws.receive_json()) - assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED - - yield ws - - if not ws.closed: - loop.run_until_complete(ws.close()) - - -@pytest.fixture -def mock_low_queue(): - """Mock a low queue.""" - with patch.object(wapi, 'MAX_PENDING_MSG', 5): - yield - - -@asyncio.coroutine -def test_auth_via_msg(no_auth_websocket_client): - """Test authenticating.""" - yield from no_auth_websocket_client.send_json({ - 'type': wapi.TYPE_AUTH, - 'api_password': API_PASSWORD - }) - - msg = yield from no_auth_websocket_client.receive_json() - - assert msg['type'] == wapi.TYPE_AUTH_OK - - -@asyncio.coroutine -def test_auth_via_msg_incorrect_pass(no_auth_websocket_client): - """Test authenticating.""" - with patch('homeassistant.components.websocket_api.process_wrong_login', - return_value=mock_coro()) as mock_process_wrong_login: - yield from no_auth_websocket_client.send_json({ - 'type': wapi.TYPE_AUTH, - 'api_password': API_PASSWORD + 'wrong' - }) - - msg = yield from no_auth_websocket_client.receive_json() - - assert mock_process_wrong_login.called - assert msg['type'] == wapi.TYPE_AUTH_INVALID - assert msg['message'] == 'Invalid access token or password' - - -@asyncio.coroutine -def test_pre_auth_only_auth_allowed(no_auth_websocket_client): - """Verify that before authentication, only auth messages are allowed.""" - yield from no_auth_websocket_client.send_json({ - 'type': wapi.TYPE_CALL_SERVICE, - 'domain': 'domain_test', - 'service': 'test_service', - 'service_data': { - 'hello': 'world' - } - }) - - msg = yield from no_auth_websocket_client.receive_json() - - assert msg['type'] == wapi.TYPE_AUTH_INVALID - assert msg['message'].startswith('Message incorrectly formatted') - - -@asyncio.coroutine -def test_invalid_message_format(websocket_client): - """Test sending invalid JSON.""" - yield from websocket_client.send_json({'type': 5}) - - msg = yield from websocket_client.receive_json() - - assert msg['type'] == wapi.TYPE_RESULT - error = msg['error'] - assert error['code'] == wapi.ERR_INVALID_FORMAT - assert error['message'].startswith('Message incorrectly formatted') - - -@asyncio.coroutine -def test_invalid_json(websocket_client): - """Test sending invalid JSON.""" - yield from websocket_client.send_str('this is not JSON') - - msg = yield from websocket_client.receive() - - assert msg.type == WSMsgType.close - - -@asyncio.coroutine -def test_quiting_hass(hass, websocket_client): - """Test sending invalid JSON.""" - with patch.object(hass.loop, 'stop'): - yield from hass.async_stop() - - msg = yield from websocket_client.receive() - - assert msg.type == WSMsgType.CLOSE - - -@asyncio.coroutine -def test_call_service(hass, websocket_client): - """Test call service command.""" - calls = [] - - @callback - def service_call(call): - calls.append(call) - - hass.services.async_register('domain_test', 'test_service', service_call) - - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_CALL_SERVICE, - 'domain': 'domain_test', - 'service': 'test_service', - 'service_data': { - 'hello': 'world' - } - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - - assert len(calls) == 1 - call = calls[0] - - assert call.domain == 'domain_test' - assert call.service == 'test_service' - assert call.data == {'hello': 'world'} - - -@asyncio.coroutine -def test_subscribe_unsubscribe_events(hass, websocket_client): - """Test subscribe/unsubscribe events command.""" - init_count = sum(hass.bus.async_listeners().values()) - - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_SUBSCRIBE_EVENTS, - 'event_type': 'test_event' - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - - # Verify we have a new listener - assert sum(hass.bus.async_listeners().values()) == init_count + 1 - - hass.bus.async_fire('ignore_event') - hass.bus.async_fire('test_event', {'hello': 'world'}) - hass.bus.async_fire('ignore_event') - - with timeout(3, loop=hass.loop): - msg = yield from websocket_client.receive_json() - - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_EVENT - event = msg['event'] - - assert event['event_type'] == 'test_event' - assert event['data'] == {'hello': 'world'} - assert event['origin'] == 'LOCAL' - - yield from websocket_client.send_json({ - 'id': 6, - 'type': wapi.TYPE_UNSUBSCRIBE_EVENTS, - 'subscription': 5 - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 6 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - - # Check our listener got unsubscribed - assert sum(hass.bus.async_listeners().values()) == init_count - - -@asyncio.coroutine -def test_get_states(hass, websocket_client): - """Test get_states command.""" - hass.states.async_set('greeting.hello', 'world') - hass.states.async_set('greeting.bye', 'universe') - - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_GET_STATES, - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - - states = [] - for state in hass.states.async_all(): - state = state.as_dict() - state['last_changed'] = state['last_changed'].isoformat() - state['last_updated'] = state['last_updated'].isoformat() - states.append(state) - - assert msg['result'] == states - - -@asyncio.coroutine -def test_get_services(hass, websocket_client): - """Test get_services command.""" - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_GET_SERVICES, - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - assert msg['result'] == hass.services.async_services() - - -@asyncio.coroutine -def test_get_config(hass, websocket_client): - """Test get_config command.""" - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_GET_CONFIG, - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert msg['success'] - - if 'components' in msg['result']: - msg['result']['components'] = set(msg['result']['components']) - if 'whitelist_external_dirs' in msg['result']: - msg['result']['whitelist_external_dirs'] = \ - set(msg['result']['whitelist_external_dirs']) - - assert msg['result'] == hass.config.as_dict() - - -@asyncio.coroutine -def test_ping(websocket_client): - """Test get_panels command.""" - yield from websocket_client.send_json({ - 'id': 5, - 'type': wapi.TYPE_PING, - }) - - msg = yield from websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_PONG - - -@asyncio.coroutine -def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): - """Test get_panels command.""" - for idx in range(10): - yield from websocket_client.send_json({ - 'id': idx + 1, - 'type': wapi.TYPE_PING, - }) - msg = yield from websocket_client.receive() - assert msg.type == WSMsgType.close - - -@asyncio.coroutine -def test_unknown_command(websocket_client): - """Test get_panels command.""" - yield from websocket_client.send_json({ - 'id': 5, - 'type': 'unknown_command', - }) - - msg = yield from websocket_client.receive_json() - assert not msg['success'] - assert msg['error']['code'] == wapi.ERR_UNKNOWN_COMMAND - - -async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): - """Test authenticating with a token.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active') as auth_active: - auth_active.return_value = True - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK - - -async def test_auth_active_user_inactive(hass, aiohttp_client, - hass_access_token): - """Test authenticating with a token.""" - refresh_token = await hass.auth.async_validate_access_token( - hass_access_token) - refresh_token.user.is_active = False - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active') as auth_active: - auth_active.return_value = True - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID - - -async def test_auth_active_with_password_not_allow(hass, aiohttp_client): - """Test authenticating with a token.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active', - return_value=True): - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'api_password': API_PASSWORD - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID - - -async def test_auth_legacy_support_with_password(hass, aiohttp_client): - """Test authenticating with a token.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active', - return_value=True),\ - patch('homeassistant.auth.AuthManager.support_legacy', - return_value=True): - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'api_password': API_PASSWORD - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK - - -async def test_auth_with_invalid_token(hass, aiohttp_client): - """Test authenticating with a token.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active') as auth_active: - auth_active.return_value = True - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'access_token': 'incorrect' - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID - - -async def test_call_service_context_with_user(hass, aiohttp_client, - hass_access_token): - """Test that the user is set in the service call context.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - calls = async_mock_service(hass, 'domain_test', 'test_service') - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - with patch('homeassistant.auth.AuthManager.active') as auth_active: - auth_active.return_value = True - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'access_token': hass_access_token - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK - - await ws.send_json({ - 'id': 5, - 'type': wapi.TYPE_CALL_SERVICE, - 'domain': 'domain_test', - 'service': 'test_service', - 'service_data': { - 'hello': 'world' - } - }) - - msg = await ws.receive_json() - assert msg['success'] - - refresh_token = await hass.auth.async_validate_access_token( - hass_access_token) - - assert len(calls) == 1 - call = calls[0] - assert call.domain == 'domain_test' - assert call.service == 'test_service' - assert call.data == {'hello': 'world'} - assert call.context.user_id == refresh_token.user.id - - -async def test_call_service_context_no_user(hass, aiohttp_client): - """Test that connection without user sets context.""" - assert await async_setup_component(hass, 'websocket_api', { - 'http': { - 'api_password': API_PASSWORD - } - }) - - calls = async_mock_service(hass, 'domain_test', 'test_service') - client = await aiohttp_client(hass.http.app) - - async with client.ws_connect(wapi.URL) as ws: - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED - - await ws.send_json({ - 'type': wapi.TYPE_AUTH, - 'api_password': API_PASSWORD - }) - - auth_msg = await ws.receive_json() - assert auth_msg['type'] == wapi.TYPE_AUTH_OK - - await ws.send_json({ - 'id': 5, - 'type': wapi.TYPE_CALL_SERVICE, - 'domain': 'domain_test', - 'service': 'test_service', - 'service_data': { - 'hello': 'world' - } - }) - - msg = await ws.receive_json() - assert msg['success'] - - assert len(calls) == 1 - call = calls[0] - assert call.domain == 'domain_test' - assert call.service == 'test_service' - assert call.data == {'hello': 'world'} - assert call.context.user_id is None - - -async def test_handler_failing(hass, websocket_client): - """Test a command that raises.""" - hass.components.websocket_api.async_register_command( - 'bla', Mock(side_effect=TypeError), - wapi.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'})) - await websocket_client.send_json({ - 'id': 5, - 'type': 'bla', - }) - - msg = await websocket_client.receive_json() - assert msg['id'] == 5 - assert msg['type'] == wapi.TYPE_RESULT - assert not msg['success'] - assert msg['error']['code'] == wapi.ERR_UNKNOWN_ERROR diff --git a/tests/components/websocket_api/__init__.py b/tests/components/websocket_api/__init__.py new file mode 100644 index 00000000000..c218c6165d4 --- /dev/null +++ b/tests/components/websocket_api/__init__.py @@ -0,0 +1,2 @@ +"""Tests for the websocket API.""" +API_PASSWORD = 'test1234' diff --git a/tests/components/websocket_api/conftest.py b/tests/components/websocket_api/conftest.py new file mode 100644 index 00000000000..063e0b43d1b --- /dev/null +++ b/tests/components/websocket_api/conftest.py @@ -0,0 +1,35 @@ +"""Fixtures for websocket tests.""" +import pytest + +from homeassistant.setup import async_setup_component +from homeassistant.components import websocket_api as wapi + +from . import API_PASSWORD + + +@pytest.fixture +def websocket_client(hass, hass_ws_client): + """Create a websocket client.""" + return hass.loop.run_until_complete(hass_ws_client(hass)) + + +@pytest.fixture +def no_auth_websocket_client(hass, loop, aiohttp_client): + """Websocket connection that requires authentication.""" + assert loop.run_until_complete( + async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + })) + + client = loop.run_until_complete(aiohttp_client(hass.http.app)) + ws = loop.run_until_complete(client.ws_connect(wapi.URL)) + + auth_ok = loop.run_until_complete(ws.receive_json()) + assert auth_ok['type'] == wapi.TYPE_AUTH_REQUIRED + + yield ws + + if not ws.closed: + loop.run_until_complete(ws.close()) diff --git a/tests/components/websocket_api/test_auth.py b/tests/components/websocket_api/test_auth.py new file mode 100644 index 00000000000..ee1de906fa1 --- /dev/null +++ b/tests/components/websocket_api/test_auth.py @@ -0,0 +1,186 @@ +"""Test auth of websocket API.""" +from unittest.mock import patch + +from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api import commands +from homeassistant.setup import async_setup_component + +from tests.common import mock_coro + +from . import API_PASSWORD + + +async def test_auth_via_msg(no_auth_websocket_client): + """Test authenticating.""" + await no_auth_websocket_client.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + }) + + msg = await no_auth_websocket_client.receive_json() + + assert msg['type'] == wapi.TYPE_AUTH_OK + + +async def test_auth_via_msg_incorrect_pass(no_auth_websocket_client): + """Test authenticating.""" + with patch('homeassistant.components.websocket_api.process_wrong_login', + return_value=mock_coro()) as mock_process_wrong_login: + await no_auth_websocket_client.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + 'wrong' + }) + + msg = await no_auth_websocket_client.receive_json() + + assert mock_process_wrong_login.called + assert msg['type'] == wapi.TYPE_AUTH_INVALID + assert msg['message'] == 'Invalid access token or password' + + +async def test_pre_auth_only_auth_allowed(no_auth_websocket_client): + """Verify that before authentication, only auth messages are allowed.""" + await no_auth_websocket_client.send_json({ + 'type': commands.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await no_auth_websocket_client.receive_json() + + assert msg['type'] == wapi.TYPE_AUTH_INVALID + assert msg['message'].startswith('Message incorrectly formatted') + + +async def test_auth_active_with_token(hass, aiohttp_client, hass_access_token): + """Test authenticating with a token.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active') as auth_active: + auth_active.return_value = True + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'access_token': hass_access_token + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + +async def test_auth_active_user_inactive(hass, aiohttp_client, + hass_access_token): + """Test authenticating with a token.""" + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + refresh_token.user.is_active = False + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active') as auth_active: + auth_active.return_value = True + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'access_token': hass_access_token + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + + +async def test_auth_active_with_password_not_allow(hass, aiohttp_client): + """Test authenticating with a token.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active', + return_value=True): + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + + +async def test_auth_legacy_support_with_password(hass, aiohttp_client): + """Test authenticating with a token.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active', + return_value=True),\ + patch('homeassistant.auth.AuthManager.support_legacy', + return_value=True): + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + +async def test_auth_with_invalid_token(hass, aiohttp_client): + """Test authenticating with a token.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active') as auth_active: + auth_active.return_value = True + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'access_token': 'incorrect' + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py new file mode 100644 index 00000000000..0eaf215afaa --- /dev/null +++ b/tests/components/websocket_api/test_commands.py @@ -0,0 +1,260 @@ +"""Tests for WebSocket API commands.""" +from unittest.mock import patch + +from async_timeout import timeout + +from homeassistant.core import callback +from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api import const, commands +from homeassistant.setup import async_setup_component + +from tests.common import async_mock_service + +from . import API_PASSWORD + + +async def test_call_service(hass, websocket_client): + """Test call service command.""" + calls = [] + + @callback + def service_call(call): + calls.append(call) + + hass.services.async_register('domain_test', 'test_service', service_call) + + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + + assert len(calls) == 1 + call = calls[0] + + assert call.domain == 'domain_test' + assert call.service == 'test_service' + assert call.data == {'hello': 'world'} + + +async def test_subscribe_unsubscribe_events(hass, websocket_client): + """Test subscribe/unsubscribe events command.""" + init_count = sum(hass.bus.async_listeners().values()) + + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_SUBSCRIBE_EVENTS, + 'event_type': 'test_event' + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + + # Verify we have a new listener + assert sum(hass.bus.async_listeners().values()) == init_count + 1 + + hass.bus.async_fire('ignore_event') + hass.bus.async_fire('test_event', {'hello': 'world'}) + hass.bus.async_fire('ignore_event') + + with timeout(3, loop=hass.loop): + msg = await websocket_client.receive_json() + + assert msg['id'] == 5 + assert msg['type'] == commands.TYPE_EVENT + event = msg['event'] + + assert event['event_type'] == 'test_event' + assert event['data'] == {'hello': 'world'} + assert event['origin'] == 'LOCAL' + + await websocket_client.send_json({ + 'id': 6, + 'type': commands.TYPE_UNSUBSCRIBE_EVENTS, + 'subscription': 5 + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 6 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + + # Check our listener got unsubscribed + assert sum(hass.bus.async_listeners().values()) == init_count + + +async def test_get_states(hass, websocket_client): + """Test get_states command.""" + hass.states.async_set('greeting.hello', 'world') + hass.states.async_set('greeting.bye', 'universe') + + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_GET_STATES, + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + + states = [] + for state in hass.states.async_all(): + state = state.as_dict() + state['last_changed'] = state['last_changed'].isoformat() + state['last_updated'] = state['last_updated'].isoformat() + states.append(state) + + assert msg['result'] == states + + +async def test_get_services(hass, websocket_client): + """Test get_services command.""" + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_GET_SERVICES, + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + assert msg['result'] == hass.services.async_services() + + +async def test_get_config(hass, websocket_client): + """Test get_config command.""" + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_GET_CONFIG, + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert msg['success'] + + if 'components' in msg['result']: + msg['result']['components'] = set(msg['result']['components']) + if 'whitelist_external_dirs' in msg['result']: + msg['result']['whitelist_external_dirs'] = \ + set(msg['result']['whitelist_external_dirs']) + + assert msg['result'] == hass.config.as_dict() + + +async def test_ping(websocket_client): + """Test get_panels command.""" + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_PING, + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == commands.TYPE_PONG + + +async def test_call_service_context_with_user(hass, aiohttp_client, + hass_access_token): + """Test that the user is set in the service call context.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + calls = async_mock_service(hass, 'domain_test', 'test_service') + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active') as auth_active: + auth_active.return_value = True + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'access_token': hass_access_token + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + await ws.send_json({ + 'id': 5, + 'type': commands.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await ws.receive_json() + assert msg['success'] + + refresh_token = await hass.auth.async_validate_access_token( + hass_access_token) + + assert len(calls) == 1 + call = calls[0] + assert call.domain == 'domain_test' + assert call.service == 'test_service' + assert call.data == {'hello': 'world'} + assert call.context.user_id == refresh_token.user.id + + +async def test_call_service_context_no_user(hass, aiohttp_client): + """Test that connection without user sets context.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + calls = async_mock_service(hass, 'domain_test', 'test_service') + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + await ws.send_json({ + 'id': 5, + 'type': commands.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await ws.receive_json() + assert msg['success'] + + assert len(calls) == 1 + call = calls[0] + assert call.domain == 'domain_test' + assert call.service == 'test_service' + assert call.data == {'hello': 'world'} + assert call.context.user_id is None diff --git a/tests/components/websocket_api/test_init.py b/tests/components/websocket_api/test_init.py new file mode 100644 index 00000000000..97acc1210fc --- /dev/null +++ b/tests/components/websocket_api/test_init.py @@ -0,0 +1,92 @@ +"""Tests for the Home Assistant Websocket API.""" +import asyncio +from unittest.mock import patch, Mock + +from aiohttp import WSMsgType +import pytest + +from homeassistant.components import websocket_api as wapi +from homeassistant.components.websocket_api import const, commands, messages + + +@pytest.fixture +def mock_low_queue(): + """Mock a low queue.""" + with patch.object(wapi, 'MAX_PENDING_MSG', 5): + yield + + +@asyncio.coroutine +def test_invalid_message_format(websocket_client): + """Test sending invalid JSON.""" + yield from websocket_client.send_json({'type': 5}) + + msg = yield from websocket_client.receive_json() + + assert msg['type'] == const.TYPE_RESULT + error = msg['error'] + assert error['code'] == const.ERR_INVALID_FORMAT + assert error['message'].startswith('Message incorrectly formatted') + + +@asyncio.coroutine +def test_invalid_json(websocket_client): + """Test sending invalid JSON.""" + yield from websocket_client.send_str('this is not JSON') + + msg = yield from websocket_client.receive() + + assert msg.type == WSMsgType.close + + +@asyncio.coroutine +def test_quiting_hass(hass, websocket_client): + """Test sending invalid JSON.""" + with patch.object(hass.loop, 'stop'): + yield from hass.async_stop() + + msg = yield from websocket_client.receive() + + assert msg.type == WSMsgType.CLOSE + + +@asyncio.coroutine +def test_pending_msg_overflow(hass, mock_low_queue, websocket_client): + """Test get_panels command.""" + for idx in range(10): + yield from websocket_client.send_json({ + 'id': idx + 1, + 'type': commands.TYPE_PING, + }) + msg = yield from websocket_client.receive() + assert msg.type == WSMsgType.close + + +@asyncio.coroutine +def test_unknown_command(websocket_client): + """Test get_panels command.""" + yield from websocket_client.send_json({ + 'id': 5, + 'type': 'unknown_command', + }) + + msg = yield from websocket_client.receive_json() + assert not msg['success'] + assert msg['error']['code'] == const.ERR_UNKNOWN_COMMAND + + +async def test_handler_failing(hass, websocket_client): + """Test a command that raises.""" + hass.components.websocket_api.async_register_command( + 'bla', Mock(side_effect=TypeError), + messages.BASE_COMMAND_MESSAGE_SCHEMA.extend({'type': 'bla'})) + await websocket_client.send_json({ + 'id': 5, + 'type': 'bla', + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert not msg['success'] + assert msg['error']['code'] == const.ERR_UNKNOWN_ERROR