From df21dd21f2c59ca4cb0ff338e122ab4513263f51 Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Fri, 30 Nov 2018 21:28:35 +0100 Subject: [PATCH] RFC: Call services directly (#18720) * Call services directly * Simplify * Type * Lint * Update name * Fix tests * Catch exceptions in HTTP view * Lint * Handle ServiceNotFound in API endpoints that call services * Type * Don't crash recorder on non-JSON serializable objects --- homeassistant/auth/mfa_modules/notify.py | 8 +- homeassistant/auth/providers/__init__.py | 6 +- homeassistant/components/api.py | 12 +- homeassistant/components/http/view.py | 8 +- homeassistant/components/mqtt_eventstream.py | 12 +- homeassistant/components/recorder/__init__.py | 22 ++- .../components/websocket_api/commands.py | 15 +- homeassistant/const.py | 4 - homeassistant/core.py | 134 ++++++------------ homeassistant/exceptions.py | 11 ++ tests/auth/mfa_modules/test_notify.py | 6 +- tests/components/climate/test_demo.py | 26 ++-- tests/components/climate/test_init.py | 8 +- tests/components/climate/test_mqtt.py | 12 +- tests/components/deconz/test_init.py | 15 +- tests/components/http/test_view.py | 50 ++++++- tests/components/media_player/test_demo.py | 15 +- .../components/media_player/test_monoprice.py | 2 +- tests/components/mqtt/test_init.py | 11 +- tests/components/notify/test_demo.py | 6 +- tests/components/test_alert.py | 1 + tests/components/test_api.py | 27 ++++ tests/components/test_input_datetime.py | 12 +- tests/components/test_logbook.py | 7 +- tests/components/test_snips.py | 17 ++- tests/components/test_wake_on_lan.py | 9 +- tests/components/water_heater/test_demo.py | 9 +- .../components/websocket_api/test_commands.py | 19 +++ tests/components/zwave/test_init.py | 2 +- tests/test_core.py | 12 +- 30 files changed, 312 insertions(+), 186 deletions(-) diff --git a/homeassistant/auth/mfa_modules/notify.py b/homeassistant/auth/mfa_modules/notify.py index 8eea3acb6ed..3c26f8b4bde 100644 --- a/homeassistant/auth/mfa_modules/notify.py +++ b/homeassistant/auth/mfa_modules/notify.py @@ -11,6 +11,7 @@ import voluptuous as vol from homeassistant.const import CONF_EXCLUDE, CONF_INCLUDE from homeassistant.core import HomeAssistant, callback +from homeassistant.exceptions import ServiceNotFound from homeassistant.helpers import config_validation as cv from . import MultiFactorAuthModule, MULTI_FACTOR_AUTH_MODULES, \ @@ -314,8 +315,11 @@ class NotifySetupFlow(SetupFlow): _generate_otp, self._secret, self._count) assert self._notify_service - await self._auth_module.async_notify( - code, self._notify_service, self._target) + try: + await self._auth_module.async_notify( + code, self._notify_service, self._target) + except ServiceNotFound: + return self.async_abort(reason='notify_service_not_exist') return self.async_show_form( step_id='setup', diff --git a/homeassistant/auth/providers/__init__.py b/homeassistant/auth/providers/__init__.py index 9ca4232b610..8828782c886 100644 --- a/homeassistant/auth/providers/__init__.py +++ b/homeassistant/auth/providers/__init__.py @@ -226,7 +226,11 @@ class LoginFlow(data_entry_flow.FlowHandler): if user_input is None and hasattr(auth_module, 'async_initialize_login_mfa_step'): - await auth_module.async_initialize_login_mfa_step(self.user.id) + try: + await auth_module.async_initialize_login_mfa_step(self.user.id) + except HomeAssistantError: + _LOGGER.exception('Error initializing MFA step') + return self.async_abort(reason='unknown_error') if user_input is not None: expires = self.created_at + MFA_SESSION_EXPIRATION diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index b001bcd0437..961350bfa89 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -9,7 +9,9 @@ import json import logging from aiohttp import web +from aiohttp.web_exceptions import HTTPBadRequest import async_timeout +import voluptuous as vol from homeassistant.bootstrap import DATA_LOGGING from homeassistant.components.http import HomeAssistantView @@ -21,7 +23,8 @@ from homeassistant.const import ( URL_API_TEMPLATE, __version__) import homeassistant.core as ha from homeassistant.auth.permissions.const import POLICY_READ -from homeassistant.exceptions import TemplateError, Unauthorized +from homeassistant.exceptions import ( + TemplateError, Unauthorized, ServiceNotFound) from homeassistant.helpers import template from homeassistant.helpers.service import async_get_all_descriptions from homeassistant.helpers.state import AsyncTrackStates @@ -339,8 +342,11 @@ class APIDomainServicesView(HomeAssistantView): "Data should be valid JSON.", HTTP_BAD_REQUEST) with AsyncTrackStates(hass) as changed_states: - await hass.services.async_call( - domain, service, data, True, self.context(request)) + try: + await hass.services.async_call( + domain, service, data, True, self.context(request)) + except (vol.Invalid, ServiceNotFound): + raise HTTPBadRequest() return self.json(changed_states) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index c8f5d788dd2..beb5c647266 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -9,7 +9,9 @@ import json import logging from aiohttp import web -from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError +from aiohttp.web_exceptions import ( + HTTPUnauthorized, HTTPInternalServerError, HTTPBadRequest) +import voluptuous as vol from homeassistant.components.http.ban import process_success_login from homeassistant.core import Context, is_callback @@ -114,6 +116,10 @@ def request_handler_factory(view, handler): if asyncio.iscoroutine(result): result = await result + except vol.Invalid: + raise HTTPBadRequest() + except exceptions.ServiceNotFound: + raise HTTPInternalServerError() except exceptions.Unauthorized: raise HTTPUnauthorized() diff --git a/homeassistant/components/mqtt_eventstream.py b/homeassistant/components/mqtt_eventstream.py index 0e01310115f..2cde7825734 100644 --- a/homeassistant/components/mqtt_eventstream.py +++ b/homeassistant/components/mqtt_eventstream.py @@ -13,7 +13,7 @@ from homeassistant.core import callback from homeassistant.components.mqtt import ( valid_publish_topic, valid_subscribe_topic) from homeassistant.const import ( - ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_SERVICE_EXECUTED, + ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) from homeassistant.core import EventOrigin, State import homeassistant.helpers.config_validation as cv @@ -69,16 +69,6 @@ def async_setup(hass, config): ): return - # Filter out all the "event service executed" events because they - # are only used internally by core as callbacks for blocking - # during the interval while a service is being executed. - # They will serve no purpose to the external system, - # and thus are unnecessary traffic. - # And at any rate it would cause an infinite loop to publish them - # because publishing to an MQTT topic itself triggers one. - if event.event_type == EVENT_SERVICE_EXECUTED: - return - event_info = {'event_type': event.event_type, 'event_data': event.data} msg = json.dumps(event_info, cls=JSONEncoder) mqtt.async_publish(pub_topic, msg) diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index c53fa051a27..15de4c3f995 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -300,14 +300,24 @@ class Recorder(threading.Thread): time.sleep(CONNECT_RETRY_WAIT) try: with session_scope(session=self.get_session()) as session: - dbevent = Events.from_event(event) - session.add(dbevent) - session.flush() + try: + dbevent = Events.from_event(event) + session.add(dbevent) + session.flush() + except (TypeError, ValueError): + _LOGGER.warning( + "Event is not JSON serializable: %s", event) if event.event_type == EVENT_STATE_CHANGED: - dbstate = States.from_event(event) - dbstate.event_id = dbevent.event_id - session.add(dbstate) + try: + dbstate = States.from_event(event) + dbstate.event_id = dbevent.event_id + session.add(dbstate) + except (TypeError, ValueError): + _LOGGER.warning( + "State is not JSON serializable: %s", + event.data.get('new_state')) + updated = True except exc.OperationalError as err: diff --git a/homeassistant/components/websocket_api/commands.py b/homeassistant/components/websocket_api/commands.py index 53d1e9af807..ff928b43873 100644 --- a/homeassistant/components/websocket_api/commands.py +++ b/homeassistant/components/websocket_api/commands.py @@ -3,7 +3,7 @@ import voluptuous as vol from homeassistant.const import MATCH_ALL, EVENT_TIME_CHANGED from homeassistant.core import callback, DOMAIN as HASS_DOMAIN -from homeassistant.exceptions import Unauthorized +from homeassistant.exceptions import Unauthorized, ServiceNotFound from homeassistant.helpers import config_validation as cv from homeassistant.helpers.service import async_get_all_descriptions @@ -141,10 +141,15 @@ async def handle_call_service(hass, connection, msg): if (msg['domain'] == HASS_DOMAIN and msg['service'] in ['restart', 'stop']): blocking = False - await hass.services.async_call( - msg['domain'], msg['service'], msg.get('service_data'), blocking, - connection.context(msg)) - connection.send_message(messages.result_message(msg['id'])) + + try: + await hass.services.async_call( + msg['domain'], msg['service'], msg.get('service_data'), blocking, + connection.context(msg)) + connection.send_message(messages.result_message(msg['id'])) + except ServiceNotFound: + connection.send_message(messages.error_message( + msg['id'], const.ERR_NOT_FOUND, 'Service not found.')) @callback diff --git a/homeassistant/const.py b/homeassistant/const.py index fc97e1bc52d..eb53140339a 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -163,7 +163,6 @@ EVENT_HOMEASSISTANT_CLOSE = 'homeassistant_close' EVENT_STATE_CHANGED = 'state_changed' EVENT_TIME_CHANGED = 'time_changed' EVENT_CALL_SERVICE = 'call_service' -EVENT_SERVICE_EXECUTED = 'service_executed' EVENT_PLATFORM_DISCOVERED = 'platform_discovered' EVENT_COMPONENT_LOADED = 'component_loaded' EVENT_SERVICE_REGISTERED = 'service_registered' @@ -233,9 +232,6 @@ ATTR_ID = 'id' # Name ATTR_NAME = 'name' -# Data for a SERVICE_EXECUTED event -ATTR_SERVICE_CALL_ID = 'service_call_id' - # Contains one string or a list of strings, each being an entity id ATTR_ENTITY_ID = 'entity_id' diff --git a/homeassistant/core.py b/homeassistant/core.py index 1754a8b5014..2a40d604ee0 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -25,18 +25,18 @@ from typing import ( # noqa: F401 pylint: disable=unused-import from async_timeout import timeout import attr import voluptuous as vol -from voluptuous.humanize import humanize_error from homeassistant.const import ( ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE, - ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE, + ATTR_SERVICE_DATA, ATTR_SECONDS, EVENT_CALL_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED, - EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, + EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, MATCH_ALL, __version__) from homeassistant import loader from homeassistant.exceptions import ( - HomeAssistantError, InvalidEntityFormatError, InvalidStateError) + HomeAssistantError, InvalidEntityFormatError, InvalidStateError, + Unauthorized, ServiceNotFound) from homeassistant.util.async_ import ( run_coroutine_threadsafe, run_callback_threadsafe, fire_coroutine_threadsafe) @@ -954,7 +954,6 @@ class ServiceRegistry: """Initialize a service registry.""" self._services = {} # type: Dict[str, Dict[str, Service]] self._hass = hass - self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE] @property def services(self) -> Dict[str, Dict[str, Service]]: @@ -1010,10 +1009,6 @@ class ServiceRegistry: else: self._services[domain] = {service: service_obj} - if self._async_unsub_call_event is None: - self._async_unsub_call_event = self._hass.bus.async_listen( - EVENT_CALL_SERVICE, self._event_to_service_call) - self._hass.bus.async_fire( EVENT_SERVICE_REGISTERED, {ATTR_DOMAIN: domain, ATTR_SERVICE: service} @@ -1092,100 +1087,61 @@ class ServiceRegistry: This method is a coroutine. """ + domain = domain.lower() + service = service.lower() context = context or Context() - call_id = uuid.uuid4().hex - event_data = { + service_data = service_data or {} + + try: + handler = self._services[domain][service] + except KeyError: + raise ServiceNotFound(domain, service) from None + + if handler.schema: + service_data = handler.schema(service_data) + + service_call = ServiceCall(domain, service, service_data, context) + + self._hass.bus.async_fire(EVENT_CALL_SERVICE, { ATTR_DOMAIN: domain.lower(), ATTR_SERVICE: service.lower(), ATTR_SERVICE_DATA: service_data, - ATTR_SERVICE_CALL_ID: call_id, - } + }) if not blocking: - self._hass.bus.async_fire( - EVENT_CALL_SERVICE, event_data, EventOrigin.local, context) + self._hass.async_create_task( + self._safe_execute(handler, service_call)) return None - fut = asyncio.Future() # type: asyncio.Future - - @callback - def service_executed(event: Event) -> None: - """Handle an executed service.""" - if event.data[ATTR_SERVICE_CALL_ID] == call_id: - fut.set_result(True) - unsub() - - unsub = self._hass.bus.async_listen( - EVENT_SERVICE_EXECUTED, service_executed) - - self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data, - EventOrigin.local, context) - - done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT) - success = bool(done) - if not success: - unsub() - return success - - async def _event_to_service_call(self, event: Event) -> None: - """Handle the SERVICE_CALLED events from the EventBus.""" - service_data = event.data.get(ATTR_SERVICE_DATA) or {} - domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore - service = event.data.get(ATTR_SERVICE).lower() # type: ignore - call_id = event.data.get(ATTR_SERVICE_CALL_ID) - - if not self.has_service(domain, service): - if event.origin == EventOrigin.local: - _LOGGER.warning("Unable to find service %s/%s", - domain, service) - return - - service_handler = self._services[domain][service] - - def fire_service_executed() -> None: - """Fire service executed event.""" - if not call_id: - return - - data = {ATTR_SERVICE_CALL_ID: call_id} - - if (service_handler.is_coroutinefunction or - service_handler.is_callback): - self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data, - EventOrigin.local, event.context) - else: - self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data, - EventOrigin.local, event.context) - try: - if service_handler.schema: - service_data = service_handler.schema(service_data) - except vol.Invalid as ex: - _LOGGER.error("Invalid service data for %s.%s: %s", - domain, service, humanize_error(service_data, ex)) - fire_service_executed() - return - - service_call = ServiceCall( - domain, service, service_data, event.context) + with timeout(SERVICE_CALL_LIMIT): + await asyncio.shield( + self._execute_service(handler, service_call)) + return True + except asyncio.TimeoutError: + return False + async def _safe_execute(self, handler: Service, + service_call: ServiceCall) -> None: + """Execute a service and catch exceptions.""" try: - if service_handler.is_callback: - service_handler.func(service_call) - fire_service_executed() - elif service_handler.is_coroutinefunction: - await service_handler.func(service_call) - fire_service_executed() - else: - def execute_service() -> None: - """Execute a service and fires a SERVICE_EXECUTED event.""" - service_handler.func(service_call) - fire_service_executed() - - await self._hass.async_add_executor_job(execute_service) + await self._execute_service(handler, service_call) + except Unauthorized: + _LOGGER.warning('Unauthorized service called %s/%s', + service_call.domain, service_call.service) except Exception: # pylint: disable=broad-except _LOGGER.exception('Error executing service %s', service_call) + async def _execute_service(self, handler: Service, + service_call: ServiceCall) -> None: + """Execute a service.""" + if handler.is_callback: + handler.func(service_call) + elif handler.is_coroutinefunction: + await handler.func(service_call) + else: + await self._hass.async_add_executor_job(handler.func, service_call) + class Config: """Configuration settings for Home Assistant.""" diff --git a/homeassistant/exceptions.py b/homeassistant/exceptions.py index 0613b7cb10c..5e2ab4988b1 100644 --- a/homeassistant/exceptions.py +++ b/homeassistant/exceptions.py @@ -58,3 +58,14 @@ class Unauthorized(HomeAssistantError): class UnknownUser(Unauthorized): """When call is made with user ID that doesn't exist.""" + + +class ServiceNotFound(HomeAssistantError): + """Raised when a service is not found.""" + + def __init__(self, domain: str, service: str) -> None: + """Initialize error.""" + super().__init__( + self, "Service {}.{} not found".format(domain, service)) + self.domain = domain + self.service = service diff --git a/tests/auth/mfa_modules/test_notify.py b/tests/auth/mfa_modules/test_notify.py index ffe0b103fc9..748b5507824 100644 --- a/tests/auth/mfa_modules/test_notify.py +++ b/tests/auth/mfa_modules/test_notify.py @@ -61,6 +61,7 @@ async def test_validating_mfa_counter(hass): 'counter': 0, 'notify_service': 'dummy', }) + async_mock_service(hass, 'notify', 'dummy') assert notify_auth_module._user_settings notify_setting = list(notify_auth_module._user_settings.values())[0] @@ -389,9 +390,8 @@ async def test_not_raise_exception_when_service_not_exist(hass): 'username': 'test-user', 'password': 'test-pass', }) - assert result['type'] == data_entry_flow.RESULT_TYPE_FORM - assert result['step_id'] == 'mfa' - assert result['data_schema'].schema.get('code') == str + assert result['type'] == data_entry_flow.RESULT_TYPE_ABORT + assert result['reason'] == 'unknown_error' # wait service call finished await hass.async_block_till_done() diff --git a/tests/components/climate/test_demo.py b/tests/components/climate/test_demo.py index 462939af23a..3a023916741 100644 --- a/tests/components/climate/test_demo.py +++ b/tests/components/climate/test_demo.py @@ -1,6 +1,9 @@ """The tests for the demo climate component.""" import unittest +import pytest +import voluptuous as vol + from homeassistant.util.unit_system import ( METRIC_SYSTEM ) @@ -57,7 +60,8 @@ class TestDemoClimate(unittest.TestCase): """Test setting the target temperature without required attribute.""" state = self.hass.states.get(ENTITY_CLIMATE) assert 21 == state.attributes.get('temperature') - common.set_temperature(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_temperature(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() assert 21 == state.attributes.get('temperature') @@ -99,9 +103,11 @@ class TestDemoClimate(unittest.TestCase): assert state.attributes.get('temperature') is None assert 21.0 == state.attributes.get('target_temp_low') assert 24.0 == state.attributes.get('target_temp_high') - common.set_temperature(self.hass, temperature=None, - entity_id=ENTITY_ECOBEE, target_temp_low=None, - target_temp_high=None) + with pytest.raises(vol.Invalid): + common.set_temperature(self.hass, temperature=None, + entity_id=ENTITY_ECOBEE, + target_temp_low=None, + target_temp_high=None) self.hass.block_till_done() state = self.hass.states.get(ENTITY_ECOBEE) assert state.attributes.get('temperature') is None @@ -112,7 +118,8 @@ class TestDemoClimate(unittest.TestCase): """Test setting the target humidity without required attribute.""" state = self.hass.states.get(ENTITY_CLIMATE) assert 67 == state.attributes.get('humidity') - common.set_humidity(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_humidity(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert 67 == state.attributes.get('humidity') @@ -130,7 +137,8 @@ class TestDemoClimate(unittest.TestCase): """Test setting fan mode without required attribute.""" state = self.hass.states.get(ENTITY_CLIMATE) assert "On High" == state.attributes.get('fan_mode') - common.set_fan_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_fan_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "On High" == state.attributes.get('fan_mode') @@ -148,7 +156,8 @@ class TestDemoClimate(unittest.TestCase): """Test setting swing mode without required attribute.""" state = self.hass.states.get(ENTITY_CLIMATE) assert "Off" == state.attributes.get('swing_mode') - common.set_swing_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_swing_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "Off" == state.attributes.get('swing_mode') @@ -170,7 +179,8 @@ class TestDemoClimate(unittest.TestCase): state = self.hass.states.get(ENTITY_CLIMATE) assert "cool" == state.attributes.get('operation_mode') assert "cool" == state.state - common.set_operation_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_operation_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "cool" == state.attributes.get('operation_mode') diff --git a/tests/components/climate/test_init.py b/tests/components/climate/test_init.py index 2e942c5988c..2aeb1228aba 100644 --- a/tests/components/climate/test_init.py +++ b/tests/components/climate/test_init.py @@ -1,6 +1,9 @@ """The tests for the climate component.""" import asyncio +import pytest +import voluptuous as vol + from homeassistant.components.climate import SET_TEMPERATURE_SCHEMA from tests.common import async_mock_service @@ -14,12 +17,11 @@ def test_set_temp_schema_no_req(hass, caplog): calls = async_mock_service(hass, domain, service, schema) data = {'operation_mode': 'test', 'entity_id': ['climate.test_id']} - yield from hass.services.async_call(domain, service, data) + with pytest.raises(vol.Invalid): + yield from hass.services.async_call(domain, service, data) yield from hass.async_block_till_done() assert len(calls) == 0 - assert 'ERROR' in caplog.text - assert 'Invalid service data' in caplog.text @asyncio.coroutine diff --git a/tests/components/climate/test_mqtt.py b/tests/components/climate/test_mqtt.py index 894fc290c38..7beb3887ae0 100644 --- a/tests/components/climate/test_mqtt.py +++ b/tests/components/climate/test_mqtt.py @@ -2,6 +2,9 @@ import unittest import copy +import pytest +import voluptuous as vol + from homeassistant.util.unit_system import ( METRIC_SYSTEM ) @@ -91,7 +94,8 @@ class TestMQTTClimate(unittest.TestCase): state = self.hass.states.get(ENTITY_CLIMATE) assert "off" == state.attributes.get('operation_mode') assert "off" == state.state - common.set_operation_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_operation_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "off" == state.attributes.get('operation_mode') @@ -177,7 +181,8 @@ class TestMQTTClimate(unittest.TestCase): state = self.hass.states.get(ENTITY_CLIMATE) assert "low" == state.attributes.get('fan_mode') - common.set_fan_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_fan_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "low" == state.attributes.get('fan_mode') @@ -225,7 +230,8 @@ class TestMQTTClimate(unittest.TestCase): state = self.hass.states.get(ENTITY_CLIMATE) assert "off" == state.attributes.get('swing_mode') - common.set_swing_mode(self.hass, None, ENTITY_CLIMATE) + with pytest.raises(vol.Invalid): + common.set_swing_mode(self.hass, None, ENTITY_CLIMATE) self.hass.block_till_done() state = self.hass.states.get(ENTITY_CLIMATE) assert "off" == state.attributes.get('swing_mode') diff --git a/tests/components/deconz/test_init.py b/tests/components/deconz/test_init.py index b83756f6ebb..5fa8ddcfe38 100644 --- a/tests/components/deconz/test_init.py +++ b/tests/components/deconz/test_init.py @@ -1,6 +1,9 @@ """Test deCONZ component setup process.""" from unittest.mock import Mock, patch +import pytest +import voluptuous as vol + from homeassistant.setup import async_setup_component from homeassistant.components import deconz @@ -163,11 +166,13 @@ async def test_service_configure(hass): await hass.async_block_till_done() # field does not start with / - with patch('pydeconz.DeconzSession.async_put_state', - return_value=mock_coro(True)): - await hass.services.async_call('deconz', 'configure', service_data={ - 'entity': 'light.test', 'field': 'state', 'data': data}) - await hass.async_block_till_done() + with pytest.raises(vol.Invalid): + with patch('pydeconz.DeconzSession.async_put_state', + return_value=mock_coro(True)): + await hass.services.async_call( + 'deconz', 'configure', service_data={ + 'entity': 'light.test', 'field': 'state', 'data': data}) + await hass.async_block_till_done() async def test_service_refresh_devices(hass): diff --git a/tests/components/http/test_view.py b/tests/components/http/test_view.py index ed97af9c764..395849f066e 100644 --- a/tests/components/http/test_view.py +++ b/tests/components/http/test_view.py @@ -1,8 +1,25 @@ """Tests for Home Assistant View.""" -from aiohttp.web_exceptions import HTTPInternalServerError -import pytest +from unittest.mock import Mock -from homeassistant.components.http.view import HomeAssistantView +from aiohttp.web_exceptions import ( + HTTPInternalServerError, HTTPBadRequest, HTTPUnauthorized) +import pytest +import voluptuous as vol + +from homeassistant.components.http.view import ( + HomeAssistantView, request_handler_factory) +from homeassistant.exceptions import ServiceNotFound, Unauthorized + +from tests.common import mock_coro_func + + +@pytest.fixture +def mock_request(): + """Mock a request.""" + return Mock( + app={'hass': Mock(is_running=True)}, + match_info={}, + ) async def test_invalid_json(caplog): @@ -13,3 +30,30 @@ async def test_invalid_json(caplog): view.json(float("NaN")) assert str(float("NaN")) in caplog.text + + +async def test_handling_unauthorized(mock_request): + """Test handling unauth exceptions.""" + with pytest.raises(HTTPUnauthorized): + await request_handler_factory( + Mock(requires_auth=False), + mock_coro_func(exception=Unauthorized) + )(mock_request) + + +async def test_handling_invalid_data(mock_request): + """Test handling unauth exceptions.""" + with pytest.raises(HTTPBadRequest): + await request_handler_factory( + Mock(requires_auth=False), + mock_coro_func(exception=vol.Invalid('yo')) + )(mock_request) + + +async def test_handling_service_not_found(mock_request): + """Test handling unauth exceptions.""" + with pytest.raises(HTTPInternalServerError): + await request_handler_factory( + Mock(requires_auth=False), + mock_coro_func(exception=ServiceNotFound('test', 'test')) + )(mock_request) diff --git a/tests/components/media_player/test_demo.py b/tests/components/media_player/test_demo.py index e986ac02065..b213cf0b5c1 100644 --- a/tests/components/media_player/test_demo.py +++ b/tests/components/media_player/test_demo.py @@ -3,6 +3,9 @@ import unittest from unittest.mock import patch import asyncio +import pytest +import voluptuous as vol + from homeassistant.setup import setup_component from homeassistant.const import HTTP_HEADER_HA_AUTH import homeassistant.components.media_player as mp @@ -43,7 +46,8 @@ class TestDemoMediaPlayer(unittest.TestCase): state = self.hass.states.get(entity_id) assert 'dvd' == state.attributes.get('source') - common.select_source(self.hass, None, entity_id) + with pytest.raises(vol.Invalid): + common.select_source(self.hass, None, entity_id) self.hass.block_till_done() state = self.hass.states.get(entity_id) assert 'dvd' == state.attributes.get('source') @@ -72,7 +76,8 @@ class TestDemoMediaPlayer(unittest.TestCase): state = self.hass.states.get(entity_id) assert 1.0 == state.attributes.get('volume_level') - common.set_volume_level(self.hass, None, entity_id) + with pytest.raises(vol.Invalid): + common.set_volume_level(self.hass, None, entity_id) self.hass.block_till_done() state = self.hass.states.get(entity_id) assert 1.0 == state.attributes.get('volume_level') @@ -201,7 +206,8 @@ class TestDemoMediaPlayer(unittest.TestCase): state.attributes.get('supported_features')) assert state.attributes.get('media_content_id') is not None - common.play_media(self.hass, None, 'some_id', ent_id) + with pytest.raises(vol.Invalid): + common.play_media(self.hass, None, 'some_id', ent_id) self.hass.block_till_done() state = self.hass.states.get(ent_id) assert 0 < (mp.SUPPORT_PLAY_MEDIA & @@ -216,7 +222,8 @@ class TestDemoMediaPlayer(unittest.TestCase): assert 'some_id' == state.attributes.get('media_content_id') assert not mock_seek.called - common.media_seek(self.hass, None, ent_id) + with pytest.raises(vol.Invalid): + common.media_seek(self.hass, None, ent_id) self.hass.block_till_done() assert not mock_seek.called common.media_seek(self.hass, 100, ent_id) diff --git a/tests/components/media_player/test_monoprice.py b/tests/components/media_player/test_monoprice.py index 417cd42187f..c6a6b3036d9 100644 --- a/tests/components/media_player/test_monoprice.py +++ b/tests/components/media_player/test_monoprice.py @@ -223,7 +223,7 @@ class TestMonopriceMediaPlayer(unittest.TestCase): # Restoring wrong media player to its previous state # Nothing should be done self.hass.services.call(DOMAIN, SERVICE_RESTORE, - {'entity_id': 'not_existing'}, + {'entity_id': 'media.not_existing'}, blocking=True) # self.hass.block_till_done() diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 5d7afbde843..81e6a7b298d 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -113,11 +113,12 @@ class TestMQTTComponent(unittest.TestCase): """ payload = "not a template" payload_template = "a template" - self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, { - mqtt.ATTR_TOPIC: "test/topic", - mqtt.ATTR_PAYLOAD: payload, - mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template - }, blocking=True) + with pytest.raises(vol.Invalid): + self.hass.services.call(mqtt.DOMAIN, mqtt.SERVICE_PUBLISH, { + mqtt.ATTR_TOPIC: "test/topic", + mqtt.ATTR_PAYLOAD: payload, + mqtt.ATTR_PAYLOAD_TEMPLATE: payload_template + }, blocking=True) assert not self.hass.data['mqtt'].async_publish.called def test_service_call_with_ascii_qos_retain_flags(self): diff --git a/tests/components/notify/test_demo.py b/tests/components/notify/test_demo.py index 57397e21ba2..4c3f3bf3f73 100644 --- a/tests/components/notify/test_demo.py +++ b/tests/components/notify/test_demo.py @@ -2,6 +2,9 @@ import unittest from unittest.mock import patch +import pytest +import voluptuous as vol + import homeassistant.components.notify as notify from homeassistant.setup import setup_component from homeassistant.components.notify import demo @@ -81,7 +84,8 @@ class TestNotifyDemo(unittest.TestCase): def test_sending_none_message(self): """Test send with None as message.""" self._setup_notify() - common.send_message(self.hass, None) + with pytest.raises(vol.Invalid): + common.send_message(self.hass, None) self.hass.block_till_done() assert len(self.events) == 0 diff --git a/tests/components/test_alert.py b/tests/components/test_alert.py index 76610421563..9fda58c37a3 100644 --- a/tests/components/test_alert.py +++ b/tests/components/test_alert.py @@ -99,6 +99,7 @@ class TestAlert(unittest.TestCase): def setUp(self): """Set up things to be run when tests are started.""" self.hass = get_test_home_assistant() + self._setup_notify() def tearDown(self): """Stop everything that was started.""" diff --git a/tests/components/test_api.py b/tests/components/test_api.py index 0bc89292855..a88c828efe8 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -6,6 +6,7 @@ from unittest.mock import patch from aiohttp import web import pytest +import voluptuous as vol from homeassistant import const from homeassistant.bootstrap import DATA_LOGGING @@ -578,3 +579,29 @@ async def test_rendering_template_legacy_user( json={"template": '{{ states.sensor.temperature.state }}'} ) assert resp.status == 401 + + +async def test_api_call_service_not_found(hass, mock_api_client): + """Test if the API failes 400 if unknown service.""" + resp = await mock_api_client.post( + const.URL_API_SERVICES_SERVICE.format( + "test_domain", "test_service")) + assert resp.status == 400 + + +async def test_api_call_service_bad_data(hass, mock_api_client): + """Test if the API failes 400 if unknown service.""" + test_value = [] + + @ha.callback + def listener(service_call): + """Record that our service got called.""" + test_value.append(1) + + hass.services.async_register("test_domain", "test_service", listener, + schema=vol.Schema({'hello': str})) + + resp = await mock_api_client.post( + const.URL_API_SERVICES_SERVICE.format( + "test_domain", "test_service"), json={'hello': 5}) + assert resp.status == 400 diff --git a/tests/components/test_input_datetime.py b/tests/components/test_input_datetime.py index a61cefe34f2..2a4d0fef09d 100644 --- a/tests/components/test_input_datetime.py +++ b/tests/components/test_input_datetime.py @@ -3,6 +3,9 @@ import asyncio import datetime +import pytest +import voluptuous as vol + from homeassistant.core import CoreState, State, Context from homeassistant.setup import async_setup_component from homeassistant.components.input_datetime import ( @@ -109,10 +112,11 @@ def test_set_invalid(hass): dt_obj = datetime.datetime(2017, 9, 7, 19, 46) time_portion = dt_obj.time() - yield from hass.services.async_call('input_datetime', 'set_datetime', { - 'entity_id': 'test_date', - 'time': time_portion - }) + with pytest.raises(vol.Invalid): + yield from hass.services.async_call('input_datetime', 'set_datetime', { + 'entity_id': 'test_date', + 'time': time_portion + }) yield from hass.async_block_till_done() state = hass.states.get(entity_id) diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index 5761ce8714b..6a272991798 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -4,6 +4,9 @@ import logging from datetime import (timedelta, datetime) import unittest +import pytest +import voluptuous as vol + from homeassistant.components import sun import homeassistant.core as ha from homeassistant.const import ( @@ -89,7 +92,9 @@ class TestComponentLogbook(unittest.TestCase): calls.append(event) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) - self.hass.services.call(logbook.DOMAIN, 'log', {}, True) + + with pytest.raises(vol.Invalid): + self.hass.services.call(logbook.DOMAIN, 'log', {}, True) # Logbook entry service call results in firing an event. # Our service call will unblock when the event listeners have been diff --git a/tests/components/test_snips.py b/tests/components/test_snips.py index bc044999bdd..977cd966981 100644 --- a/tests/components/test_snips.py +++ b/tests/components/test_snips.py @@ -2,6 +2,9 @@ import json import logging +import pytest +import voluptuous as vol + from homeassistant.bootstrap import async_setup_component from homeassistant.components.mqtt import MQTT_PUBLISH_SCHEMA import homeassistant.components.snips as snips @@ -452,12 +455,11 @@ async def test_snips_say_invalid_config(hass, caplog): snips.SERVICE_SCHEMA_SAY) data = {'text': 'Hello', 'badKey': 'boo'} - await hass.services.async_call('snips', 'say', data) + with pytest.raises(vol.Invalid): + await hass.services.async_call('snips', 'say', data) await hass.async_block_till_done() assert len(calls) == 0 - assert 'ERROR' in caplog.text - assert 'Invalid service data' in caplog.text async def test_snips_say_action_invalid(hass, caplog): @@ -466,12 +468,12 @@ async def test_snips_say_action_invalid(hass, caplog): snips.SERVICE_SCHEMA_SAY_ACTION) data = {'text': 'Hello', 'can_be_enqueued': 'notabool'} - await hass.services.async_call('snips', 'say_action', data) + + with pytest.raises(vol.Invalid): + await hass.services.async_call('snips', 'say_action', data) await hass.async_block_till_done() assert len(calls) == 0 - assert 'ERROR' in caplog.text - assert 'Invalid service data' in caplog.text async def test_snips_feedback_on(hass, caplog): @@ -510,7 +512,8 @@ async def test_snips_feedback_config(hass, caplog): snips.SERVICE_SCHEMA_FEEDBACK) data = {'site_id': 'remote', 'test': 'test'} - await hass.services.async_call('snips', 'feedback_on', data) + with pytest.raises(vol.Invalid): + await hass.services.async_call('snips', 'feedback_on', data) await hass.async_block_till_done() assert len(calls) == 0 diff --git a/tests/components/test_wake_on_lan.py b/tests/components/test_wake_on_lan.py index abaf7dd6d14..cb9f05ba47b 100644 --- a/tests/components/test_wake_on_lan.py +++ b/tests/components/test_wake_on_lan.py @@ -3,6 +3,7 @@ import asyncio from unittest import mock import pytest +import voluptuous as vol from homeassistant.setup import async_setup_component from homeassistant.components.wake_on_lan import ( @@ -34,10 +35,10 @@ def test_send_magic_packet(hass, caplog, mock_wakeonlan): assert mock_wakeonlan.mock_calls[-1][1][0] == mac assert mock_wakeonlan.mock_calls[-1][2]['ip_address'] == bc_ip - yield from hass.services.async_call( - DOMAIN, SERVICE_SEND_MAGIC_PACKET, - {"broadcast_address": bc_ip}, blocking=True) - assert 'ERROR' in caplog.text + with pytest.raises(vol.Invalid): + yield from hass.services.async_call( + DOMAIN, SERVICE_SEND_MAGIC_PACKET, + {"broadcast_address": bc_ip}, blocking=True) assert len(mock_wakeonlan.mock_calls) == 1 yield from hass.services.async_call( diff --git a/tests/components/water_heater/test_demo.py b/tests/components/water_heater/test_demo.py index 66116db8cda..d8c9c71935b 100644 --- a/tests/components/water_heater/test_demo.py +++ b/tests/components/water_heater/test_demo.py @@ -1,6 +1,9 @@ """The tests for the demo water_heater component.""" import unittest +import pytest +import voluptuous as vol + from homeassistant.util.unit_system import ( IMPERIAL_SYSTEM ) @@ -48,7 +51,8 @@ class TestDemowater_heater(unittest.TestCase): """Test setting the target temperature without required attribute.""" state = self.hass.states.get(ENTITY_WATER_HEATER) assert 119 == state.attributes.get('temperature') - common.set_temperature(self.hass, None, ENTITY_WATER_HEATER) + with pytest.raises(vol.Invalid): + common.set_temperature(self.hass, None, ENTITY_WATER_HEATER) self.hass.block_till_done() assert 119 == state.attributes.get('temperature') @@ -69,7 +73,8 @@ class TestDemowater_heater(unittest.TestCase): state = self.hass.states.get(ENTITY_WATER_HEATER) assert "eco" == state.attributes.get('operation_mode') assert "eco" == state.state - common.set_operation_mode(self.hass, None, ENTITY_WATER_HEATER) + with pytest.raises(vol.Invalid): + common.set_operation_mode(self.hass, None, ENTITY_WATER_HEATER) self.hass.block_till_done() state = self.hass.states.get(ENTITY_WATER_HEATER) assert "eco" == state.attributes.get('operation_mode') diff --git a/tests/components/websocket_api/test_commands.py b/tests/components/websocket_api/test_commands.py index dc9d0318fd1..2406eefe08e 100644 --- a/tests/components/websocket_api/test_commands.py +++ b/tests/components/websocket_api/test_commands.py @@ -49,6 +49,25 @@ async def test_call_service(hass, websocket_client): assert call.data == {'hello': 'world'} +async def test_call_service_not_found(hass, websocket_client): + """Test call service command.""" + await websocket_client.send_json({ + 'id': 5, + 'type': commands.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await websocket_client.receive_json() + assert msg['id'] == 5 + assert msg['type'] == const.TYPE_RESULT + assert not msg['success'] + assert msg['error']['code'] == const.ERR_NOT_FOUND + + async def test_subscribe_unsubscribe_events(hass, websocket_client): """Test subscribe/unsubscribe events command.""" init_count = sum(hass.bus.async_listeners().values()) diff --git a/tests/components/zwave/test_init.py b/tests/components/zwave/test_init.py index d4077345649..85cca89eefc 100644 --- a/tests/components/zwave/test_init.py +++ b/tests/components/zwave/test_init.py @@ -947,7 +947,7 @@ class TestZWaveServices(unittest.TestCase): assert self.zwave_network.stop.called assert len(self.zwave_network.stop.mock_calls) == 1 assert mock_fire.called - assert len(mock_fire.mock_calls) == 2 + assert len(mock_fire.mock_calls) == 1 assert mock_fire.mock_calls[0][1][0] == const.EVENT_NETWORK_STOP def test_rename_node(self): diff --git a/tests/test_core.py b/tests/test_core.py index 69cde6c1403..724233cbf98 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -21,7 +21,7 @@ from homeassistant.const import ( __version__, EVENT_STATE_CHANGED, ATTR_FRIENDLY_NAME, CONF_UNIT_SYSTEM, ATTR_NOW, EVENT_TIME_CHANGED, EVENT_TIMER_OUT_OF_SYNC, ATTR_SECONDS, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_CLOSE, - EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED, EVENT_SERVICE_EXECUTED) + EVENT_SERVICE_REGISTERED, EVENT_SERVICE_REMOVED) from tests.common import get_test_home_assistant, async_mock_service @@ -673,13 +673,8 @@ class TestServiceRegistry(unittest.TestCase): def test_call_non_existing_with_blocking(self): """Test non-existing with blocking.""" - prior = ha.SERVICE_CALL_LIMIT - try: - ha.SERVICE_CALL_LIMIT = 0.01 - assert not self.services.call('test_domain', 'i_do_not_exist', - blocking=True) - finally: - ha.SERVICE_CALL_LIMIT = prior + with pytest.raises(ha.ServiceNotFound): + self.services.call('test_domain', 'i_do_not_exist', blocking=True) def test_async_service(self): """Test registering and calling an async service.""" @@ -1005,4 +1000,3 @@ async def test_service_executed_with_subservices(hass): assert len(calls) == 4 assert [call.service for call in calls] == [ 'outer', 'inner', 'inner', 'outer'] - assert len(hass.bus.async_listeners().get(EVENT_SERVICE_EXECUTED, [])) == 0