From c7f4bdafc047f8d8a2430cd235bfa087b5395d1d Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen Date: Sun, 29 Jul 2018 01:53:37 +0100 Subject: [PATCH] Context (#15674) * Add context * Add context to switch/light services * Test set_state API * Lint * Fix tests * Do not include context yet in comparison * Do not pass in loop * Fix Z-Wave tests * Add websocket test without user --- homeassistant/components/api.py | 9 +- homeassistant/components/http/view.py | 10 +- homeassistant/components/light/__init__.py | 4 +- homeassistant/components/switch/__init__.py | 3 +- homeassistant/components/websocket_api.py | 19 ++- homeassistant/const.py | 3 - homeassistant/core.py | 178 ++++++++++++-------- homeassistant/helpers/entity.py | 4 +- tests/common.py | 2 +- tests/components/light/test_init.py | 27 ++- tests/components/switch/test_init.py | 25 ++- tests/components/test_api.py | 57 +++++++ tests/components/test_mqtt_eventstream.py | 6 +- tests/components/test_websocket_api.py | 93 +++++++++- tests/components/zwave/test_init.py | 8 +- tests/test_core.py | 24 +-- 16 files changed, 363 insertions(+), 109 deletions(-) diff --git a/homeassistant/components/api.py b/homeassistant/components/api.py index b80a5716061..de28eeff5ca 100644 --- a/homeassistant/components/api.py +++ b/homeassistant/components/api.py @@ -220,7 +220,8 @@ class APIEntityStateView(HomeAssistantView): is_new_state = hass.states.get(entity_id) is None # Write state - hass.states.async_set(entity_id, new_state, attributes, force_update) + hass.states.async_set(entity_id, new_state, attributes, force_update, + self.context(request)) # Read the state back for our response status_code = HTTP_CREATED if is_new_state else 200 @@ -279,7 +280,8 @@ class APIEventView(HomeAssistantView): event_data[key] = state request.app['hass'].bus.async_fire( - event_type, event_data, ha.EventOrigin.remote) + event_type, event_data, ha.EventOrigin.remote, + self.context(request)) return self.json_message("Event {} fired.".format(event_type)) @@ -316,7 +318,8 @@ class APIDomainServicesView(HomeAssistantView): "Data should be valid JSON.", HTTP_BAD_REQUEST) with AsyncTrackStates(hass) as changed_states: - await hass.services.async_call(domain, service, data, True) + await hass.services.async_call( + domain, service, data, True, self.context(request)) return self.json(changed_states) diff --git a/homeassistant/components/http/view.py b/homeassistant/components/http/view.py index 2b6c2a113c4..22ef34de54a 100644 --- a/homeassistant/components/http/view.py +++ b/homeassistant/components/http/view.py @@ -13,7 +13,7 @@ from aiohttp.web_exceptions import HTTPUnauthorized, HTTPInternalServerError import homeassistant.remote as rem from homeassistant.components.http.ban import process_success_login -from homeassistant.core import is_callback +from homeassistant.core import Context, is_callback from homeassistant.const import CONTENT_TYPE_JSON from .const import KEY_AUTHENTICATED, KEY_REAL_IP @@ -32,6 +32,14 @@ class HomeAssistantView: cors_allowed = False # pylint: disable=no-self-use + def context(self, request): + """Generate a context from a request.""" + user = request.get('hass_user') + if user is None: + return Context() + + return Context(user_id=user.id) + def json(self, result, status_code=200, headers=None): """Return a JSON response.""" try: diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 58991a8e505..8b4b2137711 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -359,7 +359,9 @@ async def async_setup(hass, config): if not light.should_poll: continue - update_tasks.append(light.async_update_ha_state(True)) + + update_tasks.append( + light.async_update_ha_state(True, service.context)) if update_tasks: await asyncio.wait(update_tasks, loop=hass.loop) diff --git a/homeassistant/components/switch/__init__.py b/homeassistant/components/switch/__init__.py index b9ee8126ed3..cb69240ee73 100644 --- a/homeassistant/components/switch/__init__.py +++ b/homeassistant/components/switch/__init__.py @@ -114,7 +114,8 @@ async def async_setup(hass, config): if not switch.should_poll: continue - update_tasks.append(switch.async_update_ha_state(True)) + update_tasks.append( + switch.async_update_ha_state(True, service.context)) if update_tasks: await asyncio.wait(update_tasks, loop=hass.loop) diff --git a/homeassistant/components/websocket_api.py b/homeassistant/components/websocket_api.py index 98e3057338a..ed478550c7a 100644 --- a/homeassistant/components/websocket_api.py +++ b/homeassistant/components/websocket_api.py @@ -18,7 +18,7 @@ from voluptuous.humanize import humanize_error from homeassistant.const import ( MATCH_ALL, EVENT_TIME_CHANGED, EVENT_HOMEASSISTANT_STOP, __version__) -from homeassistant.core import callback +from homeassistant.core import Context, callback from homeassistant.loader import bind_hass from homeassistant.remote import JSONEncoder from homeassistant.helpers import config_validation as cv @@ -262,6 +262,18 @@ class ActiveConnection: self._handle_task = None self._writer_task = None + @property + def user(self): + """Return the user associated with the connection.""" + return self.request.get('hass_user') + + def context(self, msg): + """Return a context.""" + user = self.user + if user is None: + return Context() + return Context(user_id=user.id) + def debug(self, message1, message2=''): """Print a debug message.""" _LOGGER.debug("WS %s: %s %s", id(self.wsock), message1, message2) @@ -287,7 +299,7 @@ class ActiveConnection: @callback def send_message_outside(self, message): - """Send a message to the client outside of the main task. + """Send a message to the client. Closes connection if the client is not reading the messages. @@ -508,7 +520,8 @@ def handle_call_service(hass, connection, msg): async def call_service_helper(msg): """Call a service and fire complete message.""" await hass.services.async_call( - msg['domain'], msg['service'], msg.get('service_data'), True) + msg['domain'], msg['service'], msg.get('service_data'), True, + connection.context(msg)) connection.send_message_outside(result_message(msg['id'])) hass.async_add_job(call_service_helper(msg)) diff --git a/homeassistant/const.py b/homeassistant/const.py index a84c278350f..33a00b65533 100644 --- a/homeassistant/const.py +++ b/homeassistant/const.py @@ -224,9 +224,6 @@ ATTR_ID = 'id' # Name ATTR_NAME = 'name' -# Data for a SERVICE_EXECUTED event -ATTR_SERVICE_CALL_ID = 'service_call_id' - # Contains one string or a list of strings, each being an entity id ATTR_ENTITY_ID = 'entity_id' diff --git a/homeassistant/core.py b/homeassistant/core.py index 828dfc24d6c..b17df2c11fe 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -15,6 +15,7 @@ import re import sys import threading from time import monotonic +import uuid from types import MappingProxyType # pylint: disable=unused-import @@ -23,12 +24,13 @@ from typing import ( # NOQA TYPE_CHECKING, Awaitable, Iterator) from async_timeout import timeout +import attr import voluptuous as vol from voluptuous.humanize import humanize_error from homeassistant.const import ( ATTR_DOMAIN, ATTR_FRIENDLY_NAME, ATTR_NOW, ATTR_SERVICE, - ATTR_SERVICE_CALL_ID, ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, + ATTR_SERVICE_DATA, EVENT_CALL_SERVICE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_SERVICE_EXECUTED, EVENT_SERVICE_REGISTERED, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL, EVENT_HOMEASSISTANT_CLOSE, @@ -191,7 +193,7 @@ class HomeAssistant: try: # Only block for EVENT_HOMEASSISTANT_START listener self.async_stop_track_tasks() - with timeout(TIMEOUT_EVENT_START, loop=self.loop): + with timeout(TIMEOUT_EVENT_START): await self.async_block_till_done() except asyncio.TimeoutError: _LOGGER.warning( @@ -201,7 +203,7 @@ class HomeAssistant: ', '.join(self.config.components)) # Allow automations to set up the start triggers before changing state - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0) self.state = CoreState.running _async_create_timer(self) @@ -307,16 +309,16 @@ class HomeAssistant: async def async_block_till_done(self) -> None: """Block till all pending work is done.""" # To flush out any call_soon_threadsafe - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0) while self._pending_tasks: pending = [task for task in self._pending_tasks if not task.done()] self._pending_tasks.clear() if pending: - await asyncio.wait(pending, loop=self.loop) + await asyncio.wait(pending) else: - await asyncio.sleep(0, loop=self.loop) + await asyncio.sleep(0) def stop(self) -> None: """Stop Home Assistant and shuts down all threads.""" @@ -343,6 +345,27 @@ class HomeAssistant: self.loop.stop() +@attr.s(slots=True, frozen=True) +class Context: + """The context that triggered something.""" + + user_id = attr.ib( + type=str, + default=None, + ) + id = attr.ib( + type=str, + default=attr.Factory(lambda: uuid.uuid4().hex), + ) + + def as_dict(self) -> dict: + """Return a dictionary representation of the context.""" + return { + 'id': self.id, + 'user_id': self.user_id, + } + + class EventOrigin(enum.Enum): """Represent the origin of an event.""" @@ -357,16 +380,18 @@ class EventOrigin(enum.Enum): class Event: """Representation of an event within the bus.""" - __slots__ = ['event_type', 'data', 'origin', 'time_fired'] + __slots__ = ['event_type', 'data', 'origin', 'time_fired', 'context'] def __init__(self, event_type: str, data: Optional[Dict] = None, origin: EventOrigin = EventOrigin.local, - time_fired: Optional[int] = None) -> None: + time_fired: Optional[int] = None, + context: Optional[Context] = None) -> None: """Initialize a new event.""" self.event_type = event_type self.data = data or {} self.origin = origin self.time_fired = time_fired or dt_util.utcnow() + self.context = context or Context() def as_dict(self) -> Dict: """Create a dict representation of this Event. @@ -378,6 +403,7 @@ class Event: 'data': dict(self.data), 'origin': str(self.origin), 'time_fired': self.time_fired, + 'context': self.context.as_dict() } def __repr__(self) -> str: @@ -425,14 +451,16 @@ class EventBus: ).result() def fire(self, event_type: str, event_data: Optional[Dict] = None, - origin: EventOrigin = EventOrigin.local) -> None: + origin: EventOrigin = EventOrigin.local, + context: Optional[Context] = None) -> None: """Fire an event.""" self._hass.loop.call_soon_threadsafe( - self.async_fire, event_type, event_data, origin) + self.async_fire, event_type, event_data, origin, context) @callback def async_fire(self, event_type: str, event_data: Optional[Dict] = None, - origin: EventOrigin = EventOrigin.local) -> None: + origin: EventOrigin = EventOrigin.local, + context: Optional[Context] = None) -> None: """Fire an event. This method must be run in the event loop. @@ -445,7 +473,7 @@ class EventBus: event_type != EVENT_HOMEASSISTANT_CLOSE): listeners = match_all_listeners + listeners - event = Event(event_type, event_data, origin) + event = Event(event_type, event_data, origin, None, context) if event_type != EVENT_TIME_CHANGED: _LOGGER.info("Bus:Handling %s", event) @@ -569,15 +597,17 @@ class State: attributes: extra information on entity and state last_changed: last time the state was changed, not the attributes. last_updated: last time this object was updated. + context: Context in which it was created """ __slots__ = ['entity_id', 'state', 'attributes', - 'last_changed', 'last_updated'] + 'last_changed', 'last_updated', 'context'] def __init__(self, entity_id: str, state: Any, attributes: Optional[Dict] = None, last_changed: Optional[datetime.datetime] = None, - last_updated: Optional[datetime.datetime] = None) -> None: + last_updated: Optional[datetime.datetime] = None, + context: Optional[Context] = None) -> None: """Initialize a new state.""" state = str(state) @@ -596,6 +626,7 @@ class State: self.attributes = MappingProxyType(attributes or {}) self.last_updated = last_updated or dt_util.utcnow() self.last_changed = last_changed or self.last_updated + self.context = context or Context() @property def domain(self) -> str: @@ -626,7 +657,8 @@ class State: 'state': self.state, 'attributes': dict(self.attributes), 'last_changed': self.last_changed, - 'last_updated': self.last_updated} + 'last_updated': self.last_updated, + 'context': self.context.as_dict()} @classmethod def from_dict(cls, json_dict: Dict) -> Any: @@ -650,8 +682,13 @@ class State: if isinstance(last_updated, str): last_updated = dt_util.parse_datetime(last_updated) + context = json_dict.get('context') + if context: + context = Context(**context) + return cls(json_dict['entity_id'], json_dict['state'], - json_dict.get('attributes'), last_changed, last_updated) + json_dict.get('attributes'), last_changed, last_updated, + context) def __eq__(self, other: Any) -> bool: """Return the comparison of the state.""" @@ -662,11 +699,11 @@ class State: def __repr__(self) -> str: """Return the representation of the states.""" - attr = "; {}".format(util.repr_helper(self.attributes)) \ - if self.attributes else "" + attrs = "; {}".format(util.repr_helper(self.attributes)) \ + if self.attributes else "" return "".format( - self.entity_id, self.state, attr, + self.entity_id, self.state, attrs, dt_util.as_local(self.last_changed).isoformat()) @@ -761,7 +798,8 @@ class StateMachine: def set(self, entity_id: str, new_state: Any, attributes: Optional[Dict] = None, - force_update: bool = False) -> None: + force_update: bool = False, + context: Optional[Context] = None) -> None: """Set the state of an entity, add entity if it does not exist. Attributes is an optional dict to specify attributes of this state. @@ -772,12 +810,14 @@ class StateMachine: run_callback_threadsafe( self._loop, self.async_set, entity_id, new_state, attributes, force_update, + context, ).result() @callback def async_set(self, entity_id: str, new_state: Any, attributes: Optional[Dict] = None, - force_update: bool = False) -> None: + force_update: bool = False, + context: Optional[Context] = None) -> None: """Set the state of an entity, add entity if it does not exist. Attributes is an optional dict to specify attributes of this state. @@ -804,13 +844,17 @@ class StateMachine: if same_state and same_attr: return - state = State(entity_id, new_state, attributes, last_changed) + if context is None: + context = Context() + + state = State(entity_id, new_state, attributes, last_changed, None, + context) self._states[entity_id] = state self._bus.async_fire(EVENT_STATE_CHANGED, { 'entity_id': entity_id, 'old_state': old_state, 'new_state': state, - }) + }, EventOrigin.local, context) class Service: @@ -818,7 +862,8 @@ class Service: __slots__ = ['func', 'schema', 'is_callback', 'is_coroutinefunction'] - def __init__(self, func: Callable, schema: Optional[vol.Schema]) -> None: + def __init__(self, func: Callable, schema: Optional[vol.Schema], + context: Optional[Context] = None) -> None: """Initialize a service.""" self.func = func self.schema = schema @@ -829,23 +874,25 @@ class Service: class ServiceCall: """Representation of a call to a service.""" - __slots__ = ['domain', 'service', 'data', 'call_id'] + __slots__ = ['domain', 'service', 'data', 'context'] def __init__(self, domain: str, service: str, data: Optional[Dict] = None, - call_id: Optional[str] = None) -> None: + context: Optional[Context] = None) -> None: """Initialize a service call.""" self.domain = domain.lower() self.service = service.lower() self.data = MappingProxyType(data or {}) - self.call_id = call_id + self.context = context or Context() def __repr__(self) -> str: """Return the representation of the service.""" if self.data: - return "".format( - self.domain, self.service, util.repr_helper(self.data)) + return "".format( + self.domain, self.service, self.context.id, + util.repr_helper(self.data)) - return "".format(self.domain, self.service) + return "".format( + self.domain, self.service, self.context.id) class ServiceRegistry: @@ -857,15 +904,6 @@ class ServiceRegistry: self._hass = hass self._async_unsub_call_event = None # type: Optional[CALLBACK_TYPE] - def _gen_unique_id() -> Iterator[str]: - cur_id = 1 - while True: - yield '{}-{}'.format(id(self), cur_id) - cur_id += 1 - - gen = _gen_unique_id() - self._generate_unique_id = lambda: next(gen) - @property def services(self) -> Dict[str, Dict[str, Service]]: """Return dictionary with per domain a list of available services.""" @@ -957,7 +995,8 @@ class ServiceRegistry: def call(self, domain: str, service: str, service_data: Optional[Dict] = None, - blocking: bool = False) -> Optional[bool]: + blocking: bool = False, + context: Optional[Context] = None) -> Optional[bool]: """ Call a service. @@ -975,13 +1014,14 @@ class ServiceRegistry: the keys ATTR_DOMAIN and ATTR_SERVICE in your service_data. """ return run_coroutine_threadsafe( # type: ignore - self.async_call(domain, service, service_data, blocking), + self.async_call(domain, service, service_data, blocking, context), self._hass.loop ).result() async def async_call(self, domain: str, service: str, service_data: Optional[Dict] = None, - blocking: bool = False) -> Optional[bool]: + blocking: bool = False, + context: Optional[Context] = None) -> Optional[bool]: """ Call a service. @@ -1000,44 +1040,42 @@ class ServiceRegistry: This method is a coroutine. """ - call_id = self._generate_unique_id() - + context = context or Context() event_data = { ATTR_DOMAIN: domain.lower(), ATTR_SERVICE: service.lower(), ATTR_SERVICE_DATA: service_data, - ATTR_SERVICE_CALL_ID: call_id, } - if blocking: - fut = asyncio.Future(loop=self._hass.loop) # type: asyncio.Future + if not blocking: + self._hass.bus.async_fire( + EVENT_CALL_SERVICE, event_data, EventOrigin.local, context) + return None - @callback - def service_executed(event: Event) -> None: - """Handle an executed service.""" - if event.data[ATTR_SERVICE_CALL_ID] == call_id: - fut.set_result(True) + fut = asyncio.Future() # type: asyncio.Future - unsub = self._hass.bus.async_listen( - EVENT_SERVICE_EXECUTED, service_executed) + @callback + def service_executed(event: Event) -> None: + """Handle an executed service.""" + if event.context == context: + fut.set_result(True) - self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) + unsub = self._hass.bus.async_listen( + EVENT_SERVICE_EXECUTED, service_executed) - done, _ = await asyncio.wait( - [fut], loop=self._hass.loop, timeout=SERVICE_CALL_LIMIT) - success = bool(done) - unsub() - return success + self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data, + EventOrigin.local, context) - self._hass.bus.async_fire(EVENT_CALL_SERVICE, event_data) - return None + done, _ = await asyncio.wait([fut], timeout=SERVICE_CALL_LIMIT) + success = bool(done) + unsub() + return success async def _event_to_service_call(self, event: Event) -> None: """Handle the SERVICE_CALLED events from the EventBus.""" service_data = event.data.get(ATTR_SERVICE_DATA) or {} domain = event.data.get(ATTR_DOMAIN).lower() # type: ignore service = event.data.get(ATTR_SERVICE).lower() # type: ignore - call_id = event.data.get(ATTR_SERVICE_CALL_ID) if not self.has_service(domain, service): if event.origin == EventOrigin.local: @@ -1049,16 +1087,13 @@ class ServiceRegistry: def fire_service_executed() -> None: """Fire service executed event.""" - if not call_id: - return - - data = {ATTR_SERVICE_CALL_ID: call_id} - if (service_handler.is_coroutinefunction or service_handler.is_callback): - self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, data) + self._hass.bus.async_fire(EVENT_SERVICE_EXECUTED, {}, + EventOrigin.local, event.context) else: - self._hass.bus.fire(EVENT_SERVICE_EXECUTED, data) + self._hass.bus.fire(EVENT_SERVICE_EXECUTED, {}, + EventOrigin.local, event.context) try: if service_handler.schema: @@ -1069,7 +1104,8 @@ class ServiceRegistry: fire_service_executed() return - service_call = ServiceCall(domain, service, service_data, call_id) + service_call = ServiceCall( + domain, service, service_data, event.context) try: if service_handler.is_callback: diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index f466664fc61..c356c266db6 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -179,7 +179,7 @@ class Entity: # produce undesirable effects in the entity's operation. @asyncio.coroutine - def async_update_ha_state(self, force_refresh=False): + def async_update_ha_state(self, force_refresh=False, context=None): """Update Home Assistant with current state of entity. If force_refresh == True will update entity before setting state. @@ -279,7 +279,7 @@ class Entity: pass self.hass.states.async_set( - self.entity_id, state, attr, self.force_update) + self.entity_id, state, attr, self.force_update, context) def schedule_update_ha_state(self, force_refresh=False): """Schedule an update ha state change task. diff --git a/tests/common.py b/tests/common.py index 314799b185b..5567a431e58 100644 --- a/tests/common.py +++ b/tests/common.py @@ -187,7 +187,7 @@ def async_mock_service(hass, domain, service, schema=None): """Set up a fake service & return a calls log list to this service.""" calls = [] - @asyncio.coroutine + @ha.callback def mock_service_log(call): # pylint: disable=unnecessary-lambda """Mock service call.""" calls.append(call) diff --git a/tests/components/light/test_init.py b/tests/components/light/test_init.py index 74f8c85b532..4d779eef461 100644 --- a/tests/components/light/test_init.py +++ b/tests/components/light/test_init.py @@ -5,12 +5,12 @@ import unittest.mock as mock import os from io import StringIO -from homeassistant.setup import setup_component -import homeassistant.loader as loader +from homeassistant import core, loader +from homeassistant.setup import setup_component, async_setup_component from homeassistant.const import ( ATTR_ENTITY_ID, STATE_ON, STATE_OFF, CONF_PLATFORM, SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE, ATTR_SUPPORTED_FEATURES) -import homeassistant.components.light as light +from homeassistant.components import light from homeassistant.helpers.intent import IntentHandleError from tests.common import ( @@ -475,3 +475,24 @@ async def test_intent_set_color_and_brightness(hass): assert call.data.get(ATTR_ENTITY_ID) == 'light.hello_2' assert call.data.get(light.ATTR_RGB_COLOR) == (0, 0, 255) assert call.data.get(light.ATTR_BRIGHTNESS_PCT) == 20 + + +async def test_light_context(hass): + """Test that light context works.""" + assert await async_setup_component(hass, 'light', { + 'light': { + 'platform': 'test' + } + }) + + state = hass.states.get('light.ceiling') + assert state is not None + + await hass.services.async_call('light', 'toggle', { + 'entity_id': state.entity_id, + }, True, core.Context(user_id='abcd')) + + state2 = hass.states.get('light.ceiling') + assert state2 is not None + assert state.state != state2.state + assert state2.context.user_id == 'abcd' diff --git a/tests/components/switch/test_init.py b/tests/components/switch/test_init.py index d679aa2c827..55e44299294 100644 --- a/tests/components/switch/test_init.py +++ b/tests/components/switch/test_init.py @@ -2,8 +2,8 @@ # pylint: disable=protected-access import unittest -from homeassistant.setup import setup_component -from homeassistant import loader +from homeassistant.setup import setup_component, async_setup_component +from homeassistant import core, loader from homeassistant.components import switch from homeassistant.const import STATE_ON, STATE_OFF, CONF_PLATFORM @@ -91,3 +91,24 @@ class TestSwitch(unittest.TestCase): '{} 2'.format(switch.DOMAIN): {CONF_PLATFORM: 'test2'}, } )) + + +async def test_switch_context(hass): + """Test that switch context works.""" + assert await async_setup_component(hass, 'switch', { + 'switch': { + 'platform': 'test' + } + }) + + state = hass.states.get('switch.ac') + assert state is not None + + await hass.services.async_call('switch', 'toggle', { + 'entity_id': state.entity_id, + }, True, core.Context(user_id='abcd')) + + state2 = hass.states.get('switch.ac') + assert state2 is not None + assert state.state != state2.state + assert state2.context.user_id == 'abcd' diff --git a/tests/components/test_api.py b/tests/components/test_api.py index f53010ef27f..09dc27e97c1 100644 --- a/tests/components/test_api.py +++ b/tests/components/test_api.py @@ -12,6 +12,8 @@ from homeassistant.bootstrap import DATA_LOGGING import homeassistant.core as ha from homeassistant.setup import async_setup_component +from tests.common import async_mock_service + @pytest.fixture def mock_api_client(hass, aiohttp_client): @@ -429,3 +431,58 @@ async def test_api_error_log(hass, aiohttp_client): assert mock_file.mock_calls[0][1][0] == hass.data[DATA_LOGGING] assert resp.status == 200 assert await resp.text() == 'Hello' + + +async def test_api_fire_event_context(hass, mock_api_client, + hass_access_token): + """Test if the API sets right context if we fire an event.""" + test_value = [] + + @ha.callback + def listener(event): + """Helper method that will verify our event got called.""" + test_value.append(event) + + hass.bus.async_listen("test.event", listener) + + await mock_api_client.post( + const.URL_API_EVENTS_EVENT.format("test.event"), + headers={ + 'authorization': 'Bearer {}'.format(hass_access_token.token) + }) + await hass.async_block_till_done() + + assert len(test_value) == 1 + assert test_value[0].context.user_id == \ + hass_access_token.refresh_token.user.id + + +async def test_api_call_service_context(hass, mock_api_client, + hass_access_token): + """Test if the API sets right context if we call a service.""" + calls = async_mock_service(hass, 'test_domain', 'test_service') + + await mock_api_client.post( + '/api/services/test_domain/test_service', + headers={ + 'authorization': 'Bearer {}'.format(hass_access_token.token) + }) + await hass.async_block_till_done() + + assert len(calls) == 1 + assert calls[0].context.user_id == hass_access_token.refresh_token.user.id + + +async def test_api_set_state_context(hass, mock_api_client, hass_access_token): + """Test if the API sets right context if we set state.""" + await mock_api_client.post( + '/api/states/light.kitchen', + json={ + 'state': 'on' + }, + headers={ + 'authorization': 'Bearer {}'.format(hass_access_token.token) + }) + + state = hass.states.get('light.kitchen') + assert state.context.user_id == hass_access_token.refresh_token.user.id diff --git a/tests/components/test_mqtt_eventstream.py b/tests/components/test_mqtt_eventstream.py index 38924817980..8da1311c87d 100644 --- a/tests/components/test_mqtt_eventstream.py +++ b/tests/components/test_mqtt_eventstream.py @@ -104,12 +104,14 @@ class TestMqttEventStream: "state": "on", "entity_id": e_id, "attributes": {}, - "last_changed": now.isoformat() + "last_changed": now.isoformat(), } event['event_data'] = {"new_state": new_state, "entity_id": e_id} # Verify that the message received was that expected - assert json.loads(msg) == event + result = json.loads(msg) + result['event_data']['new_state'].pop('context') + assert result == event @patch('homeassistant.components.mqtt.async_publish') def test_time_event_does_not_send_message(self, mock_pub): diff --git a/tests/components/test_websocket_api.py b/tests/components/test_websocket_api.py index dc1688bae16..1fac1af9f64 100644 --- a/tests/components/test_websocket_api.py +++ b/tests/components/test_websocket_api.py @@ -10,7 +10,7 @@ from homeassistant.core import callback from homeassistant.components import websocket_api as wapi from homeassistant.setup import async_setup_component -from tests.common import mock_coro +from tests.common import mock_coro, async_mock_service API_PASSWORD = 'test1234' @@ -443,3 +443,94 @@ async def test_auth_with_invalid_token(hass, aiohttp_client): auth_msg = await ws.receive_json() assert auth_msg['type'] == wapi.TYPE_AUTH_INVALID + + +async def test_call_service_context_with_user(hass, aiohttp_client, + hass_access_token): + """Test that the user is set in the service call context.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + calls = async_mock_service(hass, 'domain_test', 'test_service') + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + with patch('homeassistant.auth.AuthManager.active') as auth_active: + auth_active.return_value = True + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'access_token': hass_access_token.token + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + await ws.send_json({ + 'id': 5, + 'type': wapi.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await ws.receive_json() + assert msg['success'] + + assert len(calls) == 1 + call = calls[0] + assert call.domain == 'domain_test' + assert call.service == 'test_service' + assert call.data == {'hello': 'world'} + assert call.context.user_id == hass_access_token.refresh_token.user.id + + +async def test_call_service_context_no_user(hass, aiohttp_client): + """Test that connection without user sets context.""" + assert await async_setup_component(hass, 'websocket_api', { + 'http': { + 'api_password': API_PASSWORD + } + }) + + calls = async_mock_service(hass, 'domain_test', 'test_service') + client = await aiohttp_client(hass.http.app) + + async with client.ws_connect(wapi.URL) as ws: + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_REQUIRED + + await ws.send_json({ + 'type': wapi.TYPE_AUTH, + 'api_password': API_PASSWORD + }) + + auth_msg = await ws.receive_json() + assert auth_msg['type'] == wapi.TYPE_AUTH_OK + + await ws.send_json({ + 'id': 5, + 'type': wapi.TYPE_CALL_SERVICE, + 'domain': 'domain_test', + 'service': 'test_service', + 'service_data': { + 'hello': 'world' + } + }) + + msg = await ws.receive_json() + assert msg['success'] + + assert len(calls) == 1 + call = calls[0] + assert call.domain == 'domain_test' + assert call.service == 'test_service' + assert call.data == {'hello': 'world'} + assert call.context.user_id is None diff --git a/tests/components/zwave/test_init.py b/tests/components/zwave/test_init.py index e608dcccaba..39abf6f588f 100644 --- a/tests/components/zwave/test_init.py +++ b/tests/components/zwave/test_init.py @@ -163,10 +163,10 @@ def test_zwave_ready_wait(hass, mock_openzwave): asyncio_sleep = asyncio.sleep @asyncio.coroutine - def sleep(duration, loop): + def sleep(duration, loop=None): if duration > 0: sleeps.append(duration) - yield from asyncio_sleep(0, loop=loop) + yield from asyncio_sleep(0) with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow): with patch('asyncio.sleep', new=sleep): @@ -248,10 +248,10 @@ async def test_unparsed_node_discovery(hass, mock_openzwave): asyncio_sleep = asyncio.sleep - async def sleep(duration, loop): + async def sleep(duration, loop=None): if duration > 0: sleeps.append(duration) - await asyncio_sleep(0, loop=loop) + await asyncio_sleep(0) with patch('homeassistant.components.zwave.dt_util.utcnow', new=utcnow): with patch('asyncio.sleep', new=sleep): diff --git a/tests/test_core.py b/tests/test_core.py index 7633c820d2d..9de801e0bb4 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -277,6 +277,10 @@ class TestEvent(unittest.TestCase): 'data': data, 'origin': 'LOCAL', 'time_fired': now, + 'context': { + 'id': event.context.id, + 'user_id': event.context.user_id, + }, } self.assertEqual(expected, event.as_dict()) @@ -598,18 +602,16 @@ class TestStateMachine(unittest.TestCase): self.assertEqual(1, len(events)) -class TestServiceCall(unittest.TestCase): - """Test ServiceCall class.""" +def test_service_call_repr(): + """Test ServiceCall repr.""" + call = ha.ServiceCall('homeassistant', 'start') + assert str(call) == \ + "".format(call.context.id) - def test_repr(self): - """Test repr method.""" - self.assertEqual( - "", - str(ha.ServiceCall('homeassistant', 'start'))) - - self.assertEqual( - "", - str(ha.ServiceCall('homeassistant', 'start', {"fast": "yes"}))) + call2 = ha.ServiceCall('homeassistant', 'start', {'fast': 'yes'}) + assert str(call2) == \ + "".format( + call2.context.id) class TestServiceRegistry(unittest.TestCase):