diff --git a/homeassistant/components/automation/__init__.py b/homeassistant/components/automation/__init__.py index d99043f0c75..6f5396afa15 100644 --- a/homeassistant/components/automation/__init__.py +++ b/homeassistant/components/automation/__init__.py @@ -4,19 +4,26 @@ Allow to setup simple automation rules via the config file. For more details about this component, please refer to the documentation at https://home-assistant.io/components/automation/ """ +from functools import partial import logging import voluptuous as vol from homeassistant.bootstrap import prepare_setup_platform -from homeassistant.const import ATTR_ENTITY_ID, CONF_PLATFORM +from homeassistant.const import ( + ATTR_ENTITY_ID, CONF_PLATFORM, STATE_ON, SERVICE_TURN_ON, SERVICE_TURN_OFF, + SERVICE_TOGGLE) from homeassistant.components import logbook from homeassistant.exceptions import HomeAssistantError from homeassistant.helpers import extract_domain_configs, script, condition +from homeassistant.helpers.entity import ToggleEntity +from homeassistant.helpers.entity_component import EntityComponent from homeassistant.loader import get_platform +from homeassistant.util.dt import utcnow import homeassistant.helpers.config_validation as cv DOMAIN = 'automation' +ENTITY_ID_FORMAT = DOMAIN + '.{}' DEPENDENCIES = ['group'] @@ -36,6 +43,10 @@ DEFAULT_CONDITION_TYPE = CONDITION_TYPE_AND METHOD_TRIGGER = 'trigger' METHOD_IF_ACTION = 'if_action' +ATTR_LAST_TRIGGERED = 'last_triggered' +ATTR_VARIABLES = 'variables' +SERVICE_TRIGGER = 'trigger' + _LOGGER = logging.getLogger(__name__) @@ -88,41 +99,171 @@ PLATFORM_SCHEMA = vol.Schema({ vol.Required(CONF_TRIGGER): _TRIGGER_SCHEMA, vol.Required(CONF_CONDITION_TYPE, default=DEFAULT_CONDITION_TYPE): vol.All(vol.Lower, vol.Any(CONDITION_TYPE_AND, CONDITION_TYPE_OR)), - CONF_CONDITION: _CONDITION_SCHEMA, + vol.Optional(CONF_CONDITION): _CONDITION_SCHEMA, vol.Required(CONF_ACTION): cv.SCRIPT_SCHEMA, }) +SERVICE_SCHEMA = vol.Schema({ + vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, +}) + +TRIGGER_SERVICE_SCHEMA = vol.Schema({ + vol.Optional(ATTR_ENTITY_ID): cv.entity_ids, + vol.Optional(ATTR_VARIABLES, default={}): dict, +}) + + +def is_on(hass, entity_id=None): + """ + Return true if specified automation entity_id is on. + + Check all automation if no entity_id specified. + """ + entity_ids = [entity_id] if entity_id else hass.states.entity_ids(DOMAIN) + return any(hass.states.is_state(entity_id, STATE_ON) + for entity_id in entity_ids) + + +def turn_on(hass, entity_id=None): + """Turn on specified automation or all.""" + data = {ATTR_ENTITY_ID: entity_id} if entity_id else {} + hass.services.call(DOMAIN, SERVICE_TURN_ON, data) + + +def turn_off(hass, entity_id=None): + """Turn off specified automation or all.""" + data = {ATTR_ENTITY_ID: entity_id} if entity_id else {} + hass.services.call(DOMAIN, SERVICE_TURN_OFF, data) + + +def toggle(hass, entity_id=None): + """Toggle specified automation or all.""" + data = {ATTR_ENTITY_ID: entity_id} if entity_id else {} + hass.services.call(DOMAIN, SERVICE_TOGGLE, data) + + +def trigger(hass, entity_id=None): + """Trigger specified automation or all.""" + data = {ATTR_ENTITY_ID: entity_id} if entity_id else {} + hass.services.call(DOMAIN, SERVICE_TRIGGER, data) + def setup(hass, config): """Setup the automation.""" + # pylint: disable=too-many-locals + component = EntityComponent(_LOGGER, DOMAIN, hass) + success = False for config_key in extract_domain_configs(config, DOMAIN): conf = config[config_key] for list_no, config_block in enumerate(conf): - name = config_block.get(CONF_ALIAS, "{}, {}".format(config_key, - list_no)) - success = (_setup_automation(hass, config_block, name, config) or - success) + name = config_block.get(CONF_ALIAS) or "{} {}".format(config_key, + list_no) - return success + action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) + if CONF_CONDITION in config_block: + cond_func = _process_if(hass, config, config_block) -def _setup_automation(hass, config_block, name, config): - """Setup one instance of automation.""" - action = _get_action(hass, config_block.get(CONF_ACTION, {}), name) + if cond_func is None: + continue + else: + def cond_func(variables): + """Condition will always pass.""" + return True - if CONF_CONDITION in config_block: - action = _process_if(hass, config, config_block, action) + attach_triggers = partial(_process_trigger, hass, config, + config_block.get(CONF_TRIGGER, []), name) + entity = AutomationEntity(name, attach_triggers, cond_func, action) + component.add_entities((entity,)) + success = True - if action is None: - return False + if not success: + return False + + def trigger_service_handler(service_call): + """Handle automation triggers.""" + for entity in component.extract_from_service(service_call): + entity.trigger(service_call.data.get(ATTR_VARIABLES)) + + def service_handler(service_call): + """Handle automation service calls.""" + for entity in component.extract_from_service(service_call): + getattr(entity, service_call.service)() + + hass.services.register(DOMAIN, SERVICE_TRIGGER, trigger_service_handler, + schema=TRIGGER_SERVICE_SCHEMA) + + for service in (SERVICE_TURN_ON, SERVICE_TURN_OFF, SERVICE_TOGGLE): + hass.services.register(DOMAIN, service, service_handler, + schema=SERVICE_SCHEMA) - _process_trigger(hass, config, config_block.get(CONF_TRIGGER, []), name, - action) return True +class AutomationEntity(ToggleEntity): + """Entity to show status of entity.""" + + def __init__(self, name, attach_triggers, cond_func, action): + """Initialize an automation entity.""" + self._name = name + self._attach_triggers = attach_triggers + self._detach_triggers = attach_triggers(self.trigger) + self._cond_func = cond_func + self._action = action + self._enabled = True + self._last_triggered = None + + @property + def name(self): + """Name of the automation.""" + return self._name + + @property + def should_poll(self): + """No polling needed for automation entities.""" + return False + + @property + def state_attributes(self): + """Return the entity state attributes.""" + return { + ATTR_LAST_TRIGGERED: self._last_triggered + } + + @property + def is_on(self) -> bool: + """Return True if entity is on.""" + return self._enabled + + def turn_on(self, **kwargs) -> None: + """Turn the entity on.""" + if self._enabled: + return + + self._detach_triggers = self._attach_triggers(self.trigger) + self._enabled = True + self.update_ha_state() + + def turn_off(self, **kwargs) -> None: + """Turn the entity off.""" + if not self._enabled: + return + + self._detach_triggers() + self._detach_triggers = None + self._enabled = False + self.update_ha_state() + + def trigger(self, variables): + """Trigger automation.""" + if self._cond_func(variables): + self._action(variables) + self._last_triggered = utcnow() + self.update_ha_state() + + def _get_action(hass, config, name): """Return an action based on a configuration.""" script_obj = script.Script(hass, config, name) @@ -136,7 +277,7 @@ def _get_action(hass, config, name): return action -def _process_if(hass, config, p_config, action): +def _process_if(hass, config, p_config): """Process if checks.""" cond_type = p_config.get(CONF_CONDITION_TYPE, DEFAULT_CONDITION_TYPE).lower() @@ -182,29 +323,43 @@ def _process_if(hass, config, p_config, action): if cond_type == CONDITION_TYPE_AND: def if_action(variables=None): """AND all conditions.""" - if all(check(hass, variables) for check in checks): - action(variables) + return all(check(hass, variables) for check in checks) else: def if_action(variables=None): """OR all conditions.""" - if any(check(hass, variables) for check in checks): - action(variables) + return any(check(hass, variables) for check in checks) return if_action def _process_trigger(hass, config, trigger_configs, name, action): """Setup the triggers.""" + removes = [] + for conf in trigger_configs: platform = _resolve_platform(METHOD_TRIGGER, hass, config, conf.get(CONF_PLATFORM)) if platform is None: continue - if platform.trigger(hass, conf, action): - _LOGGER.info("Initialized rule %s", name) - else: + remove = platform.trigger(hass, conf, action) + + if not remove: _LOGGER.error("Error setting up rule %s", name) + continue + + _LOGGER.info("Initialized rule %s", name) + removes.append(remove) + + if not removes: + return None + + def remove_triggers(): + """Remove attached triggers.""" + for remove in removes: + remove() + + return remove_triggers def _resolve_platform(method, hass, config, platform): diff --git a/homeassistant/components/automation/event.py b/homeassistant/components/automation/event.py index 6b3160996f3..795dd94a71f 100644 --- a/homeassistant/components/automation/event.py +++ b/homeassistant/components/automation/event.py @@ -39,5 +39,4 @@ def trigger(hass, config, action): }, }) - hass.bus.listen(event_type, handle_event) - return True + return hass.bus.listen(event_type, handle_event) diff --git a/homeassistant/components/automation/mqtt.py b/homeassistant/components/automation/mqtt.py index e4a6b221e04..5cd60ff0cea 100644 --- a/homeassistant/components/automation/mqtt.py +++ b/homeassistant/components/automation/mqtt.py @@ -39,6 +39,4 @@ def trigger(hass, config, action): } }) - mqtt.subscribe(hass, topic, mqtt_automation_listener) - - return True + return mqtt.subscribe(hass, topic, mqtt_automation_listener) diff --git a/homeassistant/components/automation/numeric_state.py b/homeassistant/components/automation/numeric_state.py index 3a148b0880f..608063b4708 100644 --- a/homeassistant/components/automation/numeric_state.py +++ b/homeassistant/components/automation/numeric_state.py @@ -63,7 +63,4 @@ def trigger(hass, config, action): action(variables) - track_state_change( - hass, entity_id, state_automation_listener) - - return True + return track_state_change(hass, entity_id, state_automation_listener) diff --git a/homeassistant/components/automation/state.py b/homeassistant/components/automation/state.py index 03902c1d6e2..8e0eb5231a5 100644 --- a/homeassistant/components/automation/state.py +++ b/homeassistant/components/automation/state.py @@ -7,8 +7,7 @@ at https://home-assistant.io/components/automation/#state-trigger import voluptuous as vol import homeassistant.util.dt as dt_util -from homeassistant.const import ( - EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL, CONF_PLATFORM) +from homeassistant.const import MATCH_ALL, CONF_PLATFORM from homeassistant.helpers.event import track_state_change, track_point_in_time import homeassistant.helpers.config_validation as cv @@ -39,9 +38,13 @@ def trigger(hass, config, action): from_state = config.get(CONF_FROM, MATCH_ALL) to_state = config.get(CONF_TO) or config.get(CONF_STATE) or MATCH_ALL time_delta = config.get(CONF_FOR) + remove_state_for_cancel = None + remove_state_for_listener = None def state_automation_listener(entity, from_s, to_s): """Listen for state changes and calls action.""" + nonlocal remove_state_for_cancel, remove_state_for_listener + def call_action(): """Call action with right context.""" action({ @@ -60,26 +63,33 @@ def trigger(hass, config, action): def state_for_listener(now): """Fire on state changes after a delay and calls action.""" - hass.bus.remove_listener( - EVENT_STATE_CHANGED, attached_state_for_cancel) + remove_state_for_cancel() call_action() def state_for_cancel_listener(entity, inner_from_s, inner_to_s): """Fire on changes and cancel for listener if changed.""" if inner_to_s.state == to_s.state: return - hass.bus.remove_listener(EVENT_TIME_CHANGED, - attached_state_for_listener) - hass.bus.remove_listener(EVENT_STATE_CHANGED, - attached_state_for_cancel) + remove_state_for_listener() + remove_state_for_cancel() - attached_state_for_listener = track_point_in_time( + remove_state_for_listener = track_point_in_time( hass, state_for_listener, dt_util.utcnow() + time_delta) - attached_state_for_cancel = track_state_change( + remove_state_for_cancel = track_state_change( hass, entity, state_for_cancel_listener) - track_state_change( - hass, entity_id, state_automation_listener, from_state, to_state) + unsub = track_state_change(hass, entity_id, state_automation_listener, + from_state, to_state) - return True + def remove(): + """Remove state listeners.""" + unsub() + # pylint: disable=not-callable + if remove_state_for_cancel is not None: + remove_state_for_cancel() + + if remove_state_for_listener is not None: + remove_state_for_listener() + + return remove diff --git a/homeassistant/components/automation/sun.py b/homeassistant/components/automation/sun.py index 7666847575e..991f9b3b385 100644 --- a/homeassistant/components/automation/sun.py +++ b/homeassistant/components/automation/sun.py @@ -42,8 +42,6 @@ def trigger(hass, config, action): # Do something to call action if event == SUN_EVENT_SUNRISE: - track_sunrise(hass, call_action, offset) + return track_sunrise(hass, call_action, offset) else: - track_sunset(hass, call_action, offset) - - return True + return track_sunset(hass, call_action, offset) diff --git a/homeassistant/components/automation/template.py b/homeassistant/components/automation/template.py index 1cfbf45a24d..0891590a539 100644 --- a/homeassistant/components/automation/template.py +++ b/homeassistant/components/automation/template.py @@ -49,5 +49,4 @@ def trigger(hass, config, action): elif not template_result: already_triggered = False - track_state_change(hass, MATCH_ALL, state_changed_listener) - return True + return track_state_change(hass, MATCH_ALL, state_changed_listener) diff --git a/homeassistant/components/automation/time.py b/homeassistant/components/automation/time.py index ca80536ea96..0732e2b212c 100644 --- a/homeassistant/components/automation/time.py +++ b/homeassistant/components/automation/time.py @@ -47,7 +47,5 @@ def trigger(hass, config, action): }, }) - track_time_change(hass, time_automation_listener, - hour=hours, minute=minutes, second=seconds) - - return True + return track_time_change(hass, time_automation_listener, + hour=hours, minute=minutes, second=seconds) diff --git a/homeassistant/components/automation/zone.py b/homeassistant/components/automation/zone.py index 5578bf052c4..ec948684805 100644 --- a/homeassistant/components/automation/zone.py +++ b/homeassistant/components/automation/zone.py @@ -58,7 +58,5 @@ def trigger(hass, config, action): }, }) - track_state_change( - hass, entity_id, zone_automation_listener, MATCH_ALL, MATCH_ALL) - - return True + return track_state_change(hass, entity_id, zone_automation_listener, + MATCH_ALL, MATCH_ALL) diff --git a/homeassistant/components/cover/demo.py b/homeassistant/components/cover/demo.py index 1f1c666f339..acddfcf7c73 100644 --- a/homeassistant/components/cover/demo.py +++ b/homeassistant/components/cover/demo.py @@ -5,7 +5,6 @@ For more details about this platform, please refer to the documentation https://home-assistant.io/components/demo/ """ from homeassistant.components.cover import CoverDevice -from homeassistant.const import EVENT_TIME_CHANGED from homeassistant.helpers.event import track_utc_time_change @@ -32,8 +31,8 @@ class DemoCover(CoverDevice): self._tilt_position = tilt_position self._closing = True self._closing_tilt = True - self._listener_cover = None - self._listener_cover_tilt = None + self._unsub_listener_cover = None + self._unsub_listener_cover_tilt = None @property def name(self): @@ -120,10 +119,9 @@ class DemoCover(CoverDevice): """Stop the cover.""" if self._position is None: return - if self._listener_cover is not None: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, - self._listener_cover) - self._listener_cover = None + if self._unsub_listener_cover is not None: + self._unsub_listener_cover() + self._unsub_listener_cover = None self._set_position = None def stop_cover_tilt(self, **kwargs): @@ -131,16 +129,15 @@ class DemoCover(CoverDevice): if self._tilt_position is None: return - if self._listener_cover_tilt is not None: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, - self._listener_cover_tilt) - self._listener_cover_tilt = None + if self._unsub_listener_cover_tilt is not None: + self._unsub_listener_cover_tilt() + self._unsub_listener_cover_tilt = None self._set_tilt_position = None def _listen_cover(self): """Listen for changes in cover.""" - if self._listener_cover is None: - self._listener_cover = track_utc_time_change( + if self._unsub_listener_cover is None: + self._unsub_listener_cover = track_utc_time_change( self.hass, self._time_changed_cover) def _time_changed_cover(self, now): @@ -156,8 +153,8 @@ class DemoCover(CoverDevice): def _listen_cover_tilt(self): """Listen for changes in cover tilt.""" - if self._listener_cover_tilt is None: - self._listener_cover_tilt = track_utc_time_change( + if self._unsub_listener_cover_tilt is None: + self._unsub_listener_cover_tilt = track_utc_time_change( self.hass, self._time_changed_cover_tilt) def _time_changed_cover_tilt(self, now): diff --git a/homeassistant/components/group.py b/homeassistant/components/group.py index be998b48f23..4444b97ebe2 100644 --- a/homeassistant/components/group.py +++ b/homeassistant/components/group.py @@ -175,6 +175,7 @@ class Group(Entity): self.group_off = None self._assumed_state = False self._lock = threading.Lock() + self._unsub_state_changed = None if entity_ids is not None: self.update_tracked_entity_ids(entity_ids) @@ -236,15 +237,16 @@ class Group(Entity): def start(self): """Start tracking members.""" - track_state_change( + self._unsub_state_changed = track_state_change( self.hass, self.tracking, self._state_changed_listener) def stop(self): """Unregister the group from Home Assistant.""" self.hass.states.remove(self.entity_id) - self.hass.bus.remove_listener( - ha.EVENT_STATE_CHANGED, self._state_changed_listener) + if self._unsub_state_changed: + self._unsub_state_changed() + self._unsub_state_changed = None def update(self): """Query all members and determine current group state.""" diff --git a/homeassistant/components/mqtt/__init__.py b/homeassistant/components/mqtt/__init__.py index e06f60b6e1a..6cf8ed047ee 100644 --- a/homeassistant/components/mqtt/__init__.py +++ b/homeassistant/components/mqtt/__init__.py @@ -170,9 +170,14 @@ def subscribe(hass, topic, callback, qos=DEFAULT_QOS): callback(event.data[ATTR_TOPIC], event.data[ATTR_PAYLOAD], event.data[ATTR_QOS]) - hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, mqtt_topic_subscriber) + remove = hass.bus.listen(EVENT_MQTT_MESSAGE_RECEIVED, + mqtt_topic_subscriber) + + # Future: track subscriber count and unsubscribe in remove MQTT_CLIENT.subscribe(topic, qos) + return remove + def _setup_server(hass, config): """Try to start embedded MQTT broker.""" diff --git a/homeassistant/components/recorder/models.py b/homeassistant/components/recorder/models.py index 6e3e2db064d..671623ec564 100644 --- a/homeassistant/components/recorder/models.py +++ b/homeassistant/components/recorder/models.py @@ -92,7 +92,8 @@ class States(Base): # type: ignore else: dbstate.domain = state.domain dbstate.state = state.state - dbstate.attributes = json.dumps(dict(state.attributes)) + dbstate.attributes = json.dumps(dict(state.attributes), + cls=JSONEncoder) dbstate.last_changed = state.last_changed dbstate.last_updated = state.last_updated diff --git a/homeassistant/components/rollershutter/demo.py b/homeassistant/components/rollershutter/demo.py index 31915019c5e..6799d062e43 100644 --- a/homeassistant/components/rollershutter/demo.py +++ b/homeassistant/components/rollershutter/demo.py @@ -5,7 +5,6 @@ For more details about this platform, please refer to the documentation https://home-assistant.io/components/demo/ """ from homeassistant.components.rollershutter import RollershutterDevice -from homeassistant.const import EVENT_TIME_CHANGED from homeassistant.helpers.event import track_utc_time_change @@ -27,7 +26,7 @@ class DemoRollershutter(RollershutterDevice): self._name = name self._position = position self._moving_up = True - self._listener = None + self._unsub_listener = None @property def name(self): @@ -70,15 +69,15 @@ class DemoRollershutter(RollershutterDevice): def stop(self, **kwargs): """Stop the roller shutter.""" - if self._listener is not None: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, self._listener) - self._listener = None + if self._unsub_listener is not None: + self._unsub_listener() + self._unsub_listener = None def _listen(self): """Listen for changes.""" - if self._listener is None: - self._listener = track_utc_time_change(self.hass, - self._time_changed) + if self._unsub_listener is None: + self._unsub_listener = track_utc_time_change(self.hass, + self._time_changed) def _time_changed(self, now): """Track time changes.""" diff --git a/homeassistant/core.py b/homeassistant/core.py index b77d8356a35..dad7313bb82 100644 --- a/homeassistant/core.py +++ b/homeassistant/core.py @@ -297,6 +297,12 @@ class EventBus(object): else: self._listeners[event_type] = [listener] + def remove_listener(): + """Remove the listener.""" + self.remove_listener(event_type, listener) + + return remove_listener + def listen_once(self, event_type, listener): """Listen once for event of a specific type. diff --git a/homeassistant/helpers/event.py b/homeassistant/helpers/event.py index 9bc6910c685..ff81b693704 100644 --- a/homeassistant/helpers/event.py +++ b/homeassistant/helpers/event.py @@ -14,8 +14,7 @@ def track_state_change(hass, entity_ids, action, from_state=None, entity_ids, from_state and to_state can be string or list. Use list to match multiple. - Returns the listener that listens on the bus for EVENT_STATE_CHANGED. - Pass the return value into hass.bus.remove_listener to remove it. + Returns a function that can be called to remove the listener. """ from_state = _process_state_match(from_state) to_state = _process_state_match(to_state) @@ -50,9 +49,7 @@ def track_state_change(hass, entity_ids, action, from_state=None, event.data.get('old_state'), event.data.get('new_state')) - hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener) - - return state_change_listener + return hass.bus.listen(EVENT_STATE_CHANGED, state_change_listener) def track_point_in_time(hass, action, point_in_time): @@ -77,23 +74,20 @@ def track_point_in_utc_time(hass, action, point_in_time): """Listen for matching time_changed events.""" now = event.data[ATTR_NOW] - if now >= point_in_time and \ - not hasattr(point_in_time_listener, 'run'): + if now < point_in_time or hasattr(point_in_time_listener, 'run'): + return - # Set variable so that we will never run twice. - # Because the event bus might have to wait till a thread comes - # available to execute this listener it might occur that the - # listener gets lined up twice to be executed. This will make - # sure the second time it does nothing. - point_in_time_listener.run = True + # Set variable so that we will never run twice. + # Because the event bus might have to wait till a thread comes + # available to execute this listener it might occur that the + # listener gets lined up twice to be executed. This will make + # sure the second time it does nothing. + point_in_time_listener.run = True + remove() + action(now) - hass.bus.remove_listener(EVENT_TIME_CHANGED, - point_in_time_listener) - - action(now) - - hass.bus.listen(EVENT_TIME_CHANGED, point_in_time_listener) - return point_in_time_listener + remove = hass.bus.listen(EVENT_TIME_CHANGED, point_in_time_listener) + return remove def track_sunrise(hass, action, offset=None): @@ -112,10 +106,18 @@ def track_sunrise(hass, action, offset=None): def sunrise_automation_listener(now): """Called when it's time for action.""" + nonlocal remove track_point_in_utc_time(hass, sunrise_automation_listener, next_rise()) action() - track_point_in_utc_time(hass, sunrise_automation_listener, next_rise()) + remove = track_point_in_utc_time(hass, sunrise_automation_listener, + next_rise()) + + def remove_listener(): + """Remove sunrise listener.""" + remove() + + return remove_listener def track_sunset(hass, action, offset=None): @@ -134,10 +136,19 @@ def track_sunset(hass, action, offset=None): def sunset_automation_listener(now): """Called when it's time for action.""" - track_point_in_utc_time(hass, sunset_automation_listener, next_set()) + nonlocal remove + remove = track_point_in_utc_time(hass, sunset_automation_listener, + next_set()) action() - track_point_in_utc_time(hass, sunset_automation_listener, next_set()) + remove = track_point_in_utc_time(hass, sunset_automation_listener, + next_set()) + + def remove_listener(): + """Remove sunset listener.""" + remove() + + return remove_listener # pylint: disable=too-many-arguments @@ -152,8 +163,7 @@ def track_utc_time_change(hass, action, year=None, month=None, day=None, """Fire every time event that comes in.""" action(event.data[ATTR_NOW]) - hass.bus.listen(EVENT_TIME_CHANGED, time_change_listener) - return time_change_listener + return hass.bus.listen(EVENT_TIME_CHANGED, time_change_listener) pmp = _process_time_match year, month, day = pmp(year), pmp(month), pmp(day) @@ -178,8 +188,7 @@ def track_utc_time_change(hass, action, year=None, month=None, day=None, action(now) - hass.bus.listen(EVENT_TIME_CHANGED, pattern_time_change_listener) - return pattern_time_change_listener + return hass.bus.listen(EVENT_TIME_CHANGED, pattern_time_change_listener) # pylint: disable=too-many-arguments diff --git a/homeassistant/helpers/script.py b/homeassistant/helpers/script.py index 008fdb9374d..73ef08ce1ff 100644 --- a/homeassistant/helpers/script.py +++ b/homeassistant/helpers/script.py @@ -7,7 +7,7 @@ from typing import Optional, Sequence import voluptuous as vol from homeassistant.core import HomeAssistant -from homeassistant.const import EVENT_TIME_CHANGED, CONF_CONDITION +from homeassistant.const import CONF_CONDITION from homeassistant.helpers import ( service, condition, template, config_validation as cv) from homeassistant.helpers.event import track_point_in_utc_time @@ -47,7 +47,7 @@ class Script(): self.can_cancel = any(CONF_DELAY in action for action in self.sequence) self._lock = threading.Lock() - self._delay_listener = None + self._unsub_delay_listener = None @property def is_running(self) -> bool: @@ -72,7 +72,7 @@ class Script(): # Call ourselves in the future to continue work def script_delay(now): """Called after delay is done.""" - self._delay_listener = None + self._unsub_delay_listener = None self.run(variables) delay = action[CONF_DELAY] @@ -83,7 +83,7 @@ class Script(): cv.positive_timedelta)( template.render(self.hass, delay)) - self._delay_listener = track_point_in_utc_time( + self._unsub_delay_listener = track_point_in_utc_time( self.hass, script_delay, date_util.utcnow() + delay) self._cur = cur + 1 @@ -139,10 +139,9 @@ class Script(): def _remove_listener(self): """Remove point in time listener, if any.""" - if self._delay_listener: - self.hass.bus.remove_listener(EVENT_TIME_CHANGED, - self._delay_listener) - self._delay_listener = None + if self._unsub_delay_listener: + self._unsub_delay_listener() + self._unsub_delay_listener = None def _log(self, msg): """Logger helper.""" diff --git a/tests/common.py b/tests/common.py index e51e4ba048a..c82f6c13a0f 100644 --- a/tests/common.py +++ b/tests/common.py @@ -44,6 +44,7 @@ def get_test_home_assistant(num_threads=None): hass.config.elevation = 0 hass.config.time_zone = date_util.get_time_zone('US/Pacific') hass.config.units = METRIC_SYSTEM + hass.config.skip_pip = True if 'custom_components.test' not in loader.AVAILABLE_COMPONENTS: loader.prepare(hass) diff --git a/tests/components/automation/test_event.py b/tests/components/automation/test_event.py index ef5d380075b..80b1f507651 100644 --- a/tests/components/automation/test_event.py +++ b/tests/components/automation/test_event.py @@ -44,6 +44,13 @@ class TestAutomationEvent(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_event_with_data(self): """Test the firing of events with data.""" assert _setup_component(self.hass, automation.DOMAIN, { diff --git a/tests/components/automation/test_init.py b/tests/components/automation/test_init.py index e90ffe8d765..77727ca56b5 100644 --- a/tests/components/automation/test_init.py +++ b/tests/components/automation/test_init.py @@ -1,9 +1,11 @@ """The tests for the automation component.""" import unittest +from unittest.mock import patch from homeassistant.bootstrap import _setup_component import homeassistant.components.automation as automation from homeassistant.const import ATTR_ENTITY_ID +import homeassistant.util.dt as dt_util from tests.common import get_test_home_assistant @@ -45,6 +47,7 @@ class TestAutomation(unittest.TestCase): """Test service data.""" assert _setup_component(self.hass, automation.DOMAIN, { automation.DOMAIN: { + 'alias': 'hello', 'trigger': { 'platform': 'event', 'event_type': 'test_event', @@ -59,10 +62,17 @@ class TestAutomation(unittest.TestCase): } }) - self.hass.bus.fire('test_event') - self.hass.pool.block_till_done() - self.assertEqual(1, len(self.calls)) - self.assertEqual('event - test_event', self.calls[0].data['some']) + time = dt_util.utcnow() + + with patch('homeassistant.components.automation.utcnow', + return_value=time): + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 1 + assert 'event - test_event' == self.calls[0].data['some'] + state = self.hass.states.get('automation.hello') + assert state is not None + assert state.attributes.get('last_triggered') == time def test_service_specify_entity_id(self): """Test service data.""" @@ -347,3 +357,60 @@ class TestAutomation(unittest.TestCase): assert len(self.calls) == 2 assert self.calls[0].data['position'] == 0 assert self.calls[1].data['position'] == 1 + + def test_services(self): + """Test the automation services for turning entities on/off.""" + entity_id = 'automation.hello' + + assert self.hass.states.get(entity_id) is None + assert not automation.is_on(self.hass, entity_id) + + assert _setup_component(self.hass, automation.DOMAIN, { + automation.DOMAIN: { + 'alias': 'hello', + 'trigger': { + 'platform': 'event', + 'event_type': 'test_event', + }, + 'action': { + 'service': 'test.automation', + } + } + }) + + assert self.hass.states.get(entity_id) is not None + assert automation.is_on(self.hass, entity_id) + + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 1 + + automation.turn_off(self.hass, entity_id) + self.hass.pool.block_till_done() + + assert not automation.is_on(self.hass, entity_id) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 1 + + automation.toggle(self.hass, entity_id) + self.hass.pool.block_till_done() + + assert automation.is_on(self.hass, entity_id) + self.hass.bus.fire('test_event') + self.hass.pool.block_till_done() + assert len(self.calls) == 2 + + automation.trigger(self.hass, entity_id) + self.hass.pool.block_till_done() + assert len(self.calls) == 3 + + automation.turn_off(self.hass, entity_id) + self.hass.pool.block_till_done() + automation.trigger(self.hass, entity_id) + self.hass.pool.block_till_done() + assert len(self.calls) == 4 + + automation.turn_on(self.hass, entity_id) + self.hass.pool.block_till_done() + assert automation.is_on(self.hass, entity_id) diff --git a/tests/components/automation/test_mqtt.py b/tests/components/automation/test_mqtt.py index 29d55b424f2..9bd22d0675c 100644 --- a/tests/components/automation/test_mqtt.py +++ b/tests/components/automation/test_mqtt.py @@ -50,6 +50,12 @@ class TestAutomationMQTT(unittest.TestCase): self.assertEqual('mqtt - test-topic - test_payload', self.calls[0].data['some']) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + fire_mqtt_message(self.hass, 'test-topic', 'test_payload') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_topic_and_payload_match(self): """Test if message is fired on topic and payload match.""" assert _setup_component(self.hass, automation.DOMAIN, { diff --git a/tests/components/automation/test_numeric_state.py b/tests/components/automation/test_numeric_state.py index f7d1447632f..9ee8514052c 100644 --- a/tests/components/automation/test_numeric_state.py +++ b/tests/components/automation/test_numeric_state.py @@ -45,6 +45,14 @@ class TestAutomationNumericState(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + # Set above 12 so the automation will fire again + self.hass.states.set('test.entity', 12) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + self.hass.states.set('test.entity', 9) + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_entity_change_over_to_below(self): """"Test the firing with changed entity.""" self.hass.states.set('test.entity', 11) diff --git a/tests/components/automation/test_state.py b/tests/components/automation/test_state.py index 4a6971124b6..0b715cb365c 100644 --- a/tests/components/automation/test_state.py +++ b/tests/components/automation/test_state.py @@ -59,6 +59,12 @@ class TestAutomationState(unittest.TestCase): 'state - test.entity - hello - world - None', self.calls[0].data['some']) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + self.hass.states.set('test.entity', 'planet') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_entity_change_with_from_filter(self): """Test for firing on entity change with filter.""" assert _setup_component(self.hass, automation.DOMAIN, { diff --git a/tests/components/automation/test_sun.py b/tests/components/automation/test_sun.py index 745e7c060ca..d3bbd254e1b 100644 --- a/tests/components/automation/test_sun.py +++ b/tests/components/automation/test_sun.py @@ -54,6 +54,18 @@ class TestAutomationSun(unittest.TestCase): } }) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + + fire_time_changed(self.hass, trigger_time) + self.hass.pool.block_till_done() + self.assertEqual(0, len(self.calls)) + + with patch('homeassistant.util.dt.utcnow', + return_value=now): + automation.turn_on(self.hass) + self.hass.pool.block_till_done() + fire_time_changed(self.hass, trigger_time) self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) diff --git a/tests/components/automation/test_template.py b/tests/components/automation/test_template.py index a643b731492..a33da951cc8 100644 --- a/tests/components/automation/test_template.py +++ b/tests/components/automation/test_template.py @@ -45,6 +45,13 @@ class TestAutomationTemplate(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + + self.hass.states.set('test.entity', 'planet') + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_if_fires_on_change_str(self): """Test for firing on change.""" assert _setup_component(self.hass, automation.DOMAIN, { @@ -149,6 +156,9 @@ class TestAutomationTemplate(unittest.TestCase): } }) + self.hass.pool.block_till_done() + self.calls = [] + self.hass.states.set('test.entity', 'hello') self.hass.pool.block_till_done() self.assertEqual(0, len(self.calls)) @@ -209,9 +219,12 @@ class TestAutomationTemplate(unittest.TestCase): } }) + self.hass.pool.block_till_done() + self.calls = [] + self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() - self.assertEqual(0, len(self.calls)) + assert len(self.calls) == 0 def test_if_fires_on_change_with_template_advanced(self): """Test for firing on change with template advanced.""" @@ -237,6 +250,9 @@ class TestAutomationTemplate(unittest.TestCase): } }) + self.hass.pool.block_till_done() + self.calls = [] + self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) @@ -287,29 +303,32 @@ class TestAutomationTemplate(unittest.TestCase): } }) + self.hass.pool.block_till_done() + self.calls = [] + self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() - self.assertEqual(0, len(self.calls)) + assert len(self.calls) == 0 self.hass.states.set('test.entity', 'home') self.hass.pool.block_till_done() - self.assertEqual(1, len(self.calls)) + assert len(self.calls) == 1 self.hass.states.set('test.entity', 'work') self.hass.pool.block_till_done() - self.assertEqual(1, len(self.calls)) + assert len(self.calls) == 1 self.hass.states.set('test.entity', 'not_home') self.hass.pool.block_till_done() - self.assertEqual(1, len(self.calls)) + assert len(self.calls) == 1 self.hass.states.set('test.entity', 'world') self.hass.pool.block_till_done() - self.assertEqual(1, len(self.calls)) + assert len(self.calls) == 1 self.hass.states.set('test.entity', 'home') self.hass.pool.block_till_done() - self.assertEqual(2, len(self.calls)) + assert len(self.calls) == 2 def test_if_action(self): """Test for firing if action.""" diff --git a/tests/components/automation/test_time.py b/tests/components/automation/test_time.py index b36ce8c92b5..3c195f2eb38 100644 --- a/tests/components/automation/test_time.py +++ b/tests/components/automation/test_time.py @@ -43,7 +43,13 @@ class TestAutomationTime(unittest.TestCase): }) fire_time_changed(self.hass, dt_util.utcnow().replace(hour=0)) + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + + fire_time_changed(self.hass, dt_util.utcnow().replace(hour=0)) self.hass.pool.block_till_done() self.assertEqual(1, len(self.calls)) diff --git a/tests/components/automation/test_zone.py b/tests/components/automation/test_zone.py index 24980b466bf..9d4161547ef 100644 --- a/tests/components/automation/test_zone.py +++ b/tests/components/automation/test_zone.py @@ -74,6 +74,24 @@ class TestAutomationZone(unittest.TestCase): 'zone - test.entity - hello - hello - test', self.calls[0].data['some']) + # Set out of zone again so we can trigger call + self.hass.states.set('test.entity', 'hello', { + 'latitude': 32.881011, + 'longitude': -117.234758 + }) + self.hass.pool.block_till_done() + + automation.turn_off(self.hass) + self.hass.pool.block_till_done() + + self.hass.states.set('test.entity', 'hello', { + 'latitude': 32.880586, + 'longitude': -117.237564 + }) + self.hass.pool.block_till_done() + + self.assertEqual(1, len(self.calls)) + def test_if_not_fires_for_enter_on_zone_leave(self): """Test for not firing on zone leave.""" self.hass.states.set('test.entity', 'hello', { diff --git a/tests/components/mqtt/test_init.py b/tests/components/mqtt/test_init.py index 4c5f14bf0f1..3678585141d 100644 --- a/tests/components/mqtt/test_init.py +++ b/tests/components/mqtt/test_init.py @@ -149,7 +149,7 @@ class TestMQTT(unittest.TestCase): def test_subscribe_topic(self): """Test the subscription of a topic.""" - mqtt.subscribe(self.hass, 'test-topic', self.record_calls) + unsub = mqtt.subscribe(self.hass, 'test-topic', self.record_calls) fire_mqtt_message(self.hass, 'test-topic', 'test-payload') @@ -158,6 +158,13 @@ class TestMQTT(unittest.TestCase): self.assertEqual('test-topic', self.calls[0][0]) self.assertEqual('test-payload', self.calls[0][1]) + unsub() + + fire_mqtt_message(self.hass, 'test-topic', 'test-payload') + + self.hass.pool.block_till_done() + self.assertEqual(1, len(self.calls)) + def test_subscribe_topic_not_match(self): """Test if subscribed topic is not a match.""" mqtt.subscribe(self.hass, 'test-topic', self.record_calls) diff --git a/tests/helpers/test_event.py b/tests/helpers/test_event.py index 5d9f8d28e20..704a501eefc 100644 --- a/tests/helpers/test_event.py +++ b/tests/helpers/test_event.py @@ -65,13 +65,21 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(2, len(runs)) + unsub = track_point_in_time( + self.hass, lambda x: runs.append(1), birthday_paulus) + unsub() + + self._send_time_changed(after_birthday) + self.hass.pool.block_till_done() + self.assertEqual(2, len(runs)) + def test_track_time_change(self): """Test tracking time change.""" wildcard_runs = [] specific_runs = [] - track_time_change(self.hass, lambda x: wildcard_runs.append(1)) - track_utc_time_change( + unsub = track_time_change(self.hass, lambda x: wildcard_runs.append(1)) + unsub_utc = track_utc_time_change( self.hass, lambda x: specific_runs.append(1), second=[0, 30]) self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0)) @@ -89,6 +97,14 @@ class TestEventHelpers(unittest.TestCase): self.assertEqual(2, len(specific_runs)) self.assertEqual(3, len(wildcard_runs)) + unsub() + unsub_utc() + + self._send_time_changed(datetime(2014, 5, 24, 12, 0, 30)) + self.hass.pool.block_till_done() + self.assertEqual(2, len(specific_runs)) + self.assertEqual(3, len(wildcard_runs)) + def test_track_state_change(self): """Test track_state_change.""" # 2 lists to track how often our callbacks get called @@ -186,11 +202,12 @@ class TestEventHelpers(unittest.TestCase): # Track sunrise runs = [] - track_sunrise(self.hass, lambda: runs.append(1)) + unsub = track_sunrise(self.hass, lambda: runs.append(1)) offset_runs = [] offset = timedelta(minutes=30) - track_sunrise(self.hass, lambda: offset_runs.append(1), offset) + unsub2 = track_sunrise(self.hass, lambda: offset_runs.append(1), + offset) # run tests self._send_time_changed(next_rising - offset) @@ -208,6 +225,14 @@ class TestEventHelpers(unittest.TestCase): self.assertEqual(2, len(runs)) self.assertEqual(1, len(offset_runs)) + unsub() + unsub2() + + self._send_time_changed(next_rising + offset) + self.hass.pool.block_till_done() + self.assertEqual(2, len(runs)) + self.assertEqual(1, len(offset_runs)) + def test_track_sunset(self): """Test track the sunset.""" latitude = 32.87336 @@ -232,11 +257,11 @@ class TestEventHelpers(unittest.TestCase): # Track sunset runs = [] - track_sunset(self.hass, lambda: runs.append(1)) + unsub = track_sunset(self.hass, lambda: runs.append(1)) offset_runs = [] offset = timedelta(minutes=30) - track_sunset(self.hass, lambda: offset_runs.append(1), offset) + unsub2 = track_sunset(self.hass, lambda: offset_runs.append(1), offset) # Run tests self._send_time_changed(next_setting - offset) @@ -254,6 +279,14 @@ class TestEventHelpers(unittest.TestCase): self.assertEqual(2, len(runs)) self.assertEqual(1, len(offset_runs)) + unsub() + unsub2() + + self._send_time_changed(next_setting + offset) + self.hass.pool.block_till_done() + self.assertEqual(2, len(runs)) + self.assertEqual(1, len(offset_runs)) + def _send_time_changed(self, now): """Send a time changed event.""" self.hass.bus.fire(ha.EVENT_TIME_CHANGED, {ha.ATTR_NOW: now}) @@ -262,7 +295,7 @@ class TestEventHelpers(unittest.TestCase): """Test periodic tasks per minute.""" specific_runs = [] - track_utc_time_change( + unsub = track_utc_time_change( self.hass, lambda x: specific_runs.append(1), minute='/5') self._send_time_changed(datetime(2014, 5, 24, 12, 0, 0)) @@ -277,11 +310,17 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(2, len(specific_runs)) + unsub() + + self._send_time_changed(datetime(2014, 5, 24, 12, 5, 0)) + self.hass.pool.block_till_done() + self.assertEqual(2, len(specific_runs)) + def test_periodic_task_hour(self): """Test periodic tasks per hour.""" specific_runs = [] - track_utc_time_change( + unsub = track_utc_time_change( self.hass, lambda x: specific_runs.append(1), hour='/2') self._send_time_changed(datetime(2014, 5, 24, 22, 0, 0)) @@ -304,11 +343,17 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(3, len(specific_runs)) + unsub() + + self._send_time_changed(datetime(2014, 5, 25, 2, 0, 0)) + self.hass.pool.block_till_done() + self.assertEqual(3, len(specific_runs)) + def test_periodic_task_day(self): """Test periodic tasks per day.""" specific_runs = [] - track_utc_time_change( + unsub = track_utc_time_change( self.hass, lambda x: specific_runs.append(1), day='/2') self._send_time_changed(datetime(2014, 5, 2, 0, 0, 0)) @@ -323,11 +368,17 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(2, len(specific_runs)) + unsub() + + self._send_time_changed(datetime(2014, 5, 4, 0, 0, 0)) + self.hass.pool.block_till_done() + self.assertEqual(2, len(specific_runs)) + def test_periodic_task_year(self): """Test periodic tasks per year.""" specific_runs = [] - track_utc_time_change( + unsub = track_utc_time_change( self.hass, lambda x: specific_runs.append(1), year='/2') self._send_time_changed(datetime(2014, 5, 2, 0, 0, 0)) @@ -342,6 +393,12 @@ class TestEventHelpers(unittest.TestCase): self.hass.pool.block_till_done() self.assertEqual(2, len(specific_runs)) + unsub() + + self._send_time_changed(datetime(2016, 5, 2, 0, 0, 0)) + self.hass.pool.block_till_done() + self.assertEqual(2, len(specific_runs)) + def test_periodic_task_wrong_input(self): """Test periodic tasks with wrong input.""" specific_runs = [] diff --git a/tests/test_bootstrap.py b/tests/test_bootstrap.py index d41dc60ee15..f9abe764866 100644 --- a/tests/test_bootstrap.py +++ b/tests/test_bootstrap.py @@ -177,6 +177,7 @@ class TestBootstrap: return_value=False) def test_component_not_installed_if_requirement_fails(self, mock_install): """Component setup should fail if requirement can't install.""" + self.hass.config.skip_pip = False loader.set_component( 'comp', MockModule('comp', requirements=['package==0.0.1'])) diff --git a/tests/test_core.py b/tests/test_core.py index aa3cdd2aecc..0a67d933119 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -175,6 +175,29 @@ class TestEventBus(unittest.TestCase): # Try deleting listener while category doesn't exist either self.bus.remove_listener('test', listener) + def test_unsubscribe_listener(self): + """Test unsubscribe listener from returned function.""" + self.bus._pool.add_worker() + calls = [] + + def listener(event): + """Mock listener.""" + calls.append(event) + + unsub = self.bus.listen('test', listener) + + self.bus.fire('test') + self.bus._pool.block_till_done() + + assert len(calls) == 1 + + unsub() + + self.bus.fire('event') + self.bus._pool.block_till_done() + + assert len(calls) == 1 + def test_listen_once_event(self): """Test listen_once_event method.""" runs = []