diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index d8ff307fdde..c4eada498da 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -15,7 +15,6 @@ import voluptuous as vol from homeassistant.const import ( HTTP_BAD_REQUEST, CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE) -import homeassistant.helpers.config_validation as cv import homeassistant.util.dt as dt_util from homeassistant.components import recorder, script from homeassistant.components.frontend import register_built_in_panel @@ -28,34 +27,22 @@ DOMAIN = 'history' DEPENDENCIES = ['recorder', 'http'] CONFIG_SCHEMA = vol.Schema({ - DOMAIN: vol.Schema({ - CONF_EXCLUDE: vol.Schema({ - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): - vol.All(cv.ensure_list, [cv.string]) - }), - CONF_INCLUDE: vol.Schema({ - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): - vol.All(cv.ensure_list, [cv.string]) - }) - }), + DOMAIN: recorder.FILTER_SCHEMA, }, extra=vol.ALLOW_EXTRA) SIGNIFICANT_DOMAINS = ('thermostat', 'climate') IGNORE_DOMAINS = ('zone', 'scene',) -def last_5_states(entity_id): - """Return the last 5 states for entity_id.""" - entity_id = entity_id.lower() - - states = recorder.get_model('States') - return recorder.execute( - recorder.query('States').filter( - (states.entity_id == entity_id) & - (states.last_changed == states.last_updated) - ).order_by(states.state_id.desc()).limit(5)) +def last_recorder_run(): + """Retireve the last closed recorder run from the DB.""" + rec_runs = recorder.get_model('RecorderRuns') + with recorder.session_scope() as session: + res = recorder.query(rec_runs).order_by(rec_runs.end.desc()).first() + if res is None: + return None + session.expunge(res) + return res def get_significant_states(start_time, end_time=None, entity_id=None, @@ -91,7 +78,7 @@ def get_significant_states(start_time, end_time=None, entity_id=None, def state_changes_during_period(start_time, end_time=None, entity_id=None): """Return states changes during UTC period start_time - end_time.""" states = recorder.get_model('States') - query = recorder.query('States').filter( + query = recorder.query(states).filter( (states.last_changed == states.last_updated) & (states.last_changed > start_time)) @@ -132,7 +119,7 @@ def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None): most_recent_state_ids = most_recent_state_ids.group_by( states.entity_id).subquery() - query = recorder.query('States').join(most_recent_state_ids, and_( + query = recorder.query(states).join(most_recent_state_ids, and_( states.state_id == most_recent_state_ids.c.max_state_id)) for state in recorder.execute(query): @@ -185,27 +172,13 @@ def setup(hass, config): filters.included_entities = include[CONF_ENTITIES] filters.included_domains = include[CONF_DOMAINS] - hass.http.register_view(Last5StatesView) + recorder.get_instance() hass.http.register_view(HistoryPeriodView(filters)) register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box') return True -class Last5StatesView(HomeAssistantView): - """Handle last 5 state view requests.""" - - url = '/api/history/entity/{entity_id}/recent_states' - name = 'api:history:entity-recent-states' - - @asyncio.coroutine - def get(self, request, entity_id): - """Retrieve last 5 states of entity.""" - result = yield from request.app['hass'].loop.run_in_executor( - None, last_5_states, entity_id) - return self.json(result) - - class HistoryPeriodView(HomeAssistantView): """Handle history period requests.""" diff --git a/homeassistant/components/input_boolean.py b/homeassistant/components/input_boolean.py index 16b2d365976..1817181b184 100644 --- a/homeassistant/components/input_boolean.py +++ b/homeassistant/components/input_boolean.py @@ -15,6 +15,7 @@ from homeassistant.const import ( import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent +from homeassistant.helpers.restore_state import async_get_last_state DOMAIN = 'input_boolean' @@ -139,6 +140,14 @@ class InputBoolean(ToggleEntity): """Return true if entity is on.""" return self._state + @asyncio.coroutine + def async_added_to_hass(self): + """Called when entity about to be added to hass.""" + state = yield from async_get_last_state(self.hass, self.entity_id) + if not state: + return + self._state = state.state == 'on' + @asyncio.coroutine def async_turn_on(self, **kwargs): """Turn the entity on.""" diff --git a/homeassistant/components/light/__init__.py b/homeassistant/components/light/__init__.py index 5c3e7f4d177..05002788207 100644 --- a/homeassistant/components/light/__init__.py +++ b/homeassistant/components/light/__init__.py @@ -22,6 +22,7 @@ from homeassistant.helpers.entity import ToggleEntity from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.config_validation import PLATFORM_SCHEMA # noqa import homeassistant.helpers.config_validation as cv +from homeassistant.helpers.restore_state import async_restore_state import homeassistant.util.color as color_util from homeassistant.util.async import run_callback_threadsafe @@ -126,6 +127,14 @@ PROFILE_SCHEMA = vol.Schema( _LOGGER = logging.getLogger(__name__) +def extract_info(state): + """Extract light parameters from a state object.""" + params = {key: state.attributes[key] for key in PROP_TO_ATTR + if key in state.attributes} + params['is_on'] = state.state == STATE_ON + return params + + def is_on(hass, entity_id=None): """Return if the lights are on based on the statemachine.""" entity_id = entity_id or ENTITY_ID_ALL_LIGHTS @@ -369,3 +378,9 @@ class Light(ToggleEntity): def supported_features(self): """Flag supported features.""" return 0 + + @asyncio.coroutine + def async_added_to_hass(self): + """Component added, restore_state using platforms.""" + if hasattr(self, 'async_restore_state'): + yield from async_restore_state(self, extract_info) diff --git a/homeassistant/components/light/demo.py b/homeassistant/components/light/demo.py index 068efbbfe5f..6482e31fbaa 100644 --- a/homeassistant/components/light/demo.py +++ b/homeassistant/components/light/demo.py @@ -4,6 +4,7 @@ Demo light platform that implements lights. For more details about this platform, please refer to the documentation https://home-assistant.io/components/demo/ """ +import asyncio import random from homeassistant.components.light import ( @@ -149,3 +150,26 @@ class DemoLight(Light): # As we have disabled polling, we need to inform # Home Assistant about updates in our state ourselves. self.schedule_update_ha_state() + + @asyncio.coroutine + def async_restore_state(self, is_on, **kwargs): + """Restore the demo state.""" + self._state = is_on + + if 'brightness' in kwargs: + self._brightness = kwargs['brightness'] + + if 'color_temp' in kwargs: + self._ct = kwargs['color_temp'] + + if 'rgb_color' in kwargs: + self._rgb = kwargs['rgb_color'] + + if 'xy_color' in kwargs: + self._xy_color = kwargs['xy_color'] + + if 'white_value' in kwargs: + self._white = kwargs['white_value'] + + if 'effect' in kwargs: + self._effect = kwargs['effect'] diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index b227a8ce76a..c7a81cafb6f 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -22,6 +22,7 @@ from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS, CONF_INCLUDE, EVENT_HOMEASSISTANT_START, EVENT_HOMEASSISTANT_STOP, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) +from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.typing import ConfigType, QueryType @@ -42,36 +43,35 @@ CONNECT_RETRY_WAIT = 10 QUERY_RETRY_WAIT = 0.1 ERROR_QUERY = "Error during query: %s" +FILTER_SCHEMA = vol.Schema({ + vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }), + vol.Optional(CONF_INCLUDE, default={}): vol.Schema({ + vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, + vol.Optional(CONF_DOMAINS, default=[]): + vol.All(cv.ensure_list, [cv.string]) + }) +}) + CONFIG_SCHEMA = vol.Schema({ - DOMAIN: vol.Schema({ + DOMAIN: FILTER_SCHEMA.extend({ vol.Optional(CONF_PURGE_DAYS): vol.All(vol.Coerce(int), vol.Range(min=1)), vol.Optional(CONF_DB_URL): cv.string, - vol.Optional(CONF_EXCLUDE, default={}): vol.Schema({ - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): - vol.All(cv.ensure_list, [cv.string]) - }), - vol.Optional(CONF_INCLUDE, default={}): vol.Schema({ - vol.Optional(CONF_ENTITIES, default=[]): cv.entity_ids, - vol.Optional(CONF_DOMAINS, default=[]): - vol.All(cv.ensure_list, [cv.string]) - }) }) }, extra=vol.ALLOW_EXTRA) _INSTANCE = None # type: Any _LOGGER = logging.getLogger(__name__) -# These classes will be populated during setup() -# scoped_session, in the same thread session_scope() stays the same -_SESSION = None - @contextmanager def session_scope(): """Provide a transactional scope around a series of operations.""" - session = _SESSION() + session = _INSTANCE.get_session() try: yield session session.commit() @@ -83,15 +83,28 @@ def session_scope(): session.close() +def get_instance() -> None: + """Throw error if recorder not initialized.""" + if _INSTANCE is None: + raise RuntimeError("Recorder not initialized.") + + ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident") + if ident is not None and ident == threading.get_ident(): + raise RuntimeError('Cannot be called from within the event loop') + + _wait(_INSTANCE.db_ready, "Database not ready") + + return _INSTANCE + + # pylint: disable=invalid-sequence-index def execute(qry: QueryType) -> List[Any]: """Query the database and convert the objects to HA native form. This method also retries a few times in the case of stale connections. """ - _verify_instance() - - import sqlalchemy.exc + get_instance() + from sqlalchemy.exc import SQLAlchemyError with session_scope() as session: for _ in range(0, RETRIES): try: @@ -99,7 +112,7 @@ def execute(qry: QueryType) -> List[Any]: row for row in (row.to_native() for row in qry) if row is not None] - except sqlalchemy.exc.SQLAlchemyError as err: + except SQLAlchemyError as err: _LOGGER.error(ERROR_QUERY, err) session.rollback() time.sleep(QUERY_RETRY_WAIT) @@ -111,13 +124,13 @@ def run_information(point_in_time: Optional[datetime]=None): There is also the run that covers point_in_time. """ - _verify_instance() + ins = get_instance() recorder_runs = get_model('RecorderRuns') - if point_in_time is None or point_in_time > _INSTANCE.recording_start: + if point_in_time is None or point_in_time > ins.recording_start: return recorder_runs( end=None, - start=_INSTANCE.recording_start, + start=ins.recording_start, closed_incorrect=False) with session_scope() as session: @@ -148,17 +161,19 @@ def setup(hass: HomeAssistant, config: ConfigType) -> bool: exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {}) _INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url, include=include, exclude=exclude) + _INSTANCE.start() return True -def query(model_name: Union[str, Any], *args) -> QueryType: +def query(model_name: Union[str, Any], session=None, *args) -> QueryType: """Helper to return a query handle.""" - _verify_instance() + if session is None: + session = get_instance().get_session() if isinstance(model_name, str): - return _SESSION().query(get_model(model_name), *args) - return _SESSION().query(model_name, *args) + return session.query(get_model(model_name), *args) + return session.query(model_name, *args) def get_model(model_name: str) -> Any: @@ -185,6 +200,7 @@ class Recorder(threading.Thread): self.recording_start = dt_util.utcnow() self.db_url = uri self.db_ready = threading.Event() + self.start_recording = threading.Event() self.engine = None # type: Any self._run = None # type: Any @@ -195,23 +211,26 @@ class Recorder(threading.Thread): def start_recording(event): """Start recording.""" - self.start() + self.start_recording.set() hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_recording) hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown) hass.bus.listen(MATCH_ALL, self.event_listener) + self.get_session = None + def run(self): """Start processing events to save.""" from homeassistant.components.recorder.models import Events, States - import sqlalchemy.exc + from sqlalchemy.exc import SQLAlchemyError while True: try: self._setup_connection() self._setup_run() + self.db_ready.set() break - except sqlalchemy.exc.SQLAlchemyError as err: + except SQLAlchemyError as err: _LOGGER.error("Error during connection setup: %s (retrying " "in %s seconds)", err, CONNECT_RETRY_WAIT) time.sleep(CONNECT_RETRY_WAIT) @@ -220,6 +239,8 @@ class Recorder(threading.Thread): async_track_time_interval( self.hass, self._purge_old_data, timedelta(days=2)) + _wait(self.start_recording, "Waiting to start recording") + while True: event = self.queue.get() @@ -275,10 +296,9 @@ class Recorder(threading.Thread): def shutdown(self, event): """Tell the recorder to shut down.""" global _INSTANCE # pylint: disable=global-statement - _INSTANCE = None - self.queue.put(None) self.join() + _INSTANCE = None def block_till_done(self): """Block till all events processed.""" @@ -286,15 +306,10 @@ class Recorder(threading.Thread): def block_till_db_ready(self): """Block until the database session is ready.""" - self.db_ready.wait(10) - while not self.db_ready.is_set(): - _LOGGER.warning('Database not ready, waiting another 10 seconds.') - self.db_ready.wait(10) + _wait(self.db_ready, "Database not ready") def _setup_connection(self): """Ensure database is ready to fly.""" - global _SESSION # pylint: disable=invalid-name,global-statement - import homeassistant.components.recorder.models as models from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session @@ -312,9 +327,8 @@ class Recorder(threading.Thread): models.Base.metadata.create_all(self.engine) session_factory = sessionmaker(bind=self.engine) - _SESSION = scoped_session(session_factory) + self.get_session = scoped_session(session_factory) self._migrate_schema() - self.db_ready.set() def _migrate_schema(self): """Check if the schema needs to be upgraded.""" @@ -396,16 +410,16 @@ class Recorder(threading.Thread): def _close_connection(self): """Close the connection.""" - global _SESSION # pylint: disable=invalid-name,global-statement self.engine.dispose() self.engine = None - _SESSION = None + self.get_session = None def _setup_run(self): """Log the start of the current run.""" recorder_runs = get_model('RecorderRuns') with session_scope() as session: - for run in query('RecorderRuns').filter_by(end=None): + for run in query( + recorder_runs, session=session).filter_by(end=None): run.closed_incorrect = True run.end = self.recording_start _LOGGER.warning("Ended unfinished session (id=%s from %s)", @@ -482,13 +496,13 @@ class Recorder(threading.Thread): return False -def _verify_instance() -> None: - """Throw error if recorder not initialized.""" - if _INSTANCE is None: - raise RuntimeError("Recorder not initialized.") - - ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident") - if ident is not None and ident == threading.get_ident(): - raise RuntimeError('Cannot be called from within the event loop') - - _INSTANCE.block_till_db_ready() +def _wait(event, message): + """Event wait helper.""" + for retry in (10, 20, 30): + event.wait(10) + if event.is_set(): + return + msg = message + " ({} seconds)".format(retry) + _LOGGER.warning(msg) + if not event.is_set(): + raise HomeAssistantError(msg) diff --git a/homeassistant/components/sensor/history_stats.py b/homeassistant/components/sensor/history_stats.py index 8d03f4754e0..b019e6745fb 100644 --- a/homeassistant/components/sensor/history_stats.py +++ b/homeassistant/components/sensor/history_stats.py @@ -199,7 +199,7 @@ class HistoryStatsSensor(Entity): if self._start is not None: try: start_rendered = self._start.render() - except TemplateError as ex: + except (TemplateError, TypeError) as ex: HistoryStatsHelper.handle_template_exception(ex, 'start') return start = dt_util.parse_datetime(start_rendered) @@ -216,7 +216,7 @@ class HistoryStatsSensor(Entity): if self._end is not None: try: end_rendered = self._end.render() - except TemplateError as ex: + except (TemplateError, TypeError) as ex: HistoryStatsHelper.handle_template_exception(ex, 'end') return end = dt_util.parse_datetime(end_rendered) diff --git a/homeassistant/helpers/entity.py b/homeassistant/helpers/entity.py index 0705f60a9b6..a3bf1a03386 100644 --- a/homeassistant/helpers/entity.py +++ b/homeassistant/helpers/entity.py @@ -288,7 +288,7 @@ class Entity(object): self.hass.add_job(self.async_update_ha_state(force_refresh)) def remove(self) -> None: - """Remove entitiy from HASS.""" + """Remove entity from HASS.""" run_coroutine_threadsafe( self.async_remove(), self.hass.loop ).result() diff --git a/homeassistant/helpers/entity_component.py b/homeassistant/helpers/entity_component.py index 949dc578ca3..ad88045039f 100644 --- a/homeassistant/helpers/entity_component.py +++ b/homeassistant/helpers/entity_component.py @@ -202,6 +202,10 @@ class EntityComponent(object): 'Invalid entity id: {}'.format(entity.entity_id)) self.entities[entity.entity_id] = entity + + if hasattr(entity, 'async_added_to_hass'): + yield from entity.async_added_to_hass() + yield from entity.async_update_ha_state() return True diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py new file mode 100644 index 00000000000..dfed0f52413 --- /dev/null +++ b/homeassistant/helpers/restore_state.py @@ -0,0 +1,82 @@ +"""Support for restoring entity states on startup.""" +import asyncio +import logging +from datetime import timedelta + +from homeassistant.core import HomeAssistant, CoreState, callback +from homeassistant.const import EVENT_HOMEASSISTANT_START +from homeassistant.components.history import get_states, last_recorder_run +from homeassistant.components.recorder import DOMAIN as _RECORDER +import homeassistant.util.dt as dt_util + +_LOGGER = logging.getLogger(__name__) + +DATA_RESTORE_CACHE = 'restore_state_cache' +_LOCK = 'restore_lock' + + +def _load_restore_cache(hass: HomeAssistant): + """Load the restore cache to be used by other components.""" + @callback + def remove_cache(event): + """Remove the states cache.""" + hass.data.pop(DATA_RESTORE_CACHE, None) + + hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache) + + last_run = last_recorder_run() + + if last_run is None or last_run.end is None: + _LOGGER.debug('Not creating cache - no suitable last run found: %s', + last_run) + hass.data[DATA_RESTORE_CACHE] = {} + return + + last_end_time = last_run.end - timedelta(seconds=1) + # Unfortunately the recorder_run model do not return offset-aware time + last_end_time = last_end_time.replace(tzinfo=dt_util.UTC) + _LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time) + + states = get_states(last_end_time, run=last_run) + + # Cache the states + hass.data[DATA_RESTORE_CACHE] = { + state.entity_id: state for state in states} + _LOGGER.debug('Created cache with %s', list(hass.data[DATA_RESTORE_CACHE])) + + +@asyncio.coroutine +def async_get_last_state(hass, entity_id: str): + """Helper to restore state.""" + if (_RECORDER not in hass.config.components or + hass.state != CoreState.starting): + return None + + if DATA_RESTORE_CACHE in hass.data: + return hass.data[DATA_RESTORE_CACHE].get(entity_id) + + if _LOCK not in hass.data: + hass.data[_LOCK] = asyncio.Lock(loop=hass.loop) + + with (yield from hass.data[_LOCK]): + if DATA_RESTORE_CACHE not in hass.data: + yield from hass.loop.run_in_executor( + None, _load_restore_cache, hass) + + return hass.data[DATA_RESTORE_CACHE].get(entity_id) + + +@asyncio.coroutine +def async_restore_state(entity, extract_info): + """Helper to call entity.async_restore_state with cached info.""" + if entity.hass.state != CoreState.starting: + _LOGGER.debug("Not restoring state: State is not starting: %s", + entity.hass.state) + return + + state = yield from async_get_last_state(entity.hass, entity.entity_id) + + if not state: + return + + yield from entity.async_restore_state(**extract_info(state)) diff --git a/homeassistant/loader.py b/homeassistant/loader.py index 243de03ec92..60ba924f46c 100644 --- a/homeassistant/loader.py +++ b/homeassistant/loader.py @@ -197,8 +197,8 @@ def load_order_components(components: Sequence[str]) -> OrderedSet: load_order.update(comp_load_order) # Push some to first place in load order - for comp in ('mqtt_eventstream', 'mqtt', 'logger', - 'recorder', 'introduction'): + for comp in ('mqtt_eventstream', 'mqtt', 'recorder', + 'introduction', 'logger'): if comp in load_order: load_order.promote(comp) diff --git a/tests/common.py b/tests/common.py index 6dd1ecd4586..bba53243a44 100644 --- a/tests/common.py +++ b/tests/common.py @@ -22,7 +22,7 @@ from homeassistant.const import ( STATE_ON, STATE_OFF, DEVICE_DEFAULT_NAME, EVENT_TIME_CHANGED, EVENT_STATE_CHANGED, EVENT_PLATFORM_DISCOVERED, ATTR_SERVICE, ATTR_DISCOVERED, SERVER_PORT) -from homeassistant.components import sun, mqtt +from homeassistant.components import sun, mqtt, recorder from homeassistant.components.http.auth import auth_middleware from homeassistant.components.http.const import ( KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS) @@ -452,3 +452,31 @@ def assert_setup_component(count, domain=None): res_len = 0 if res is None else len(res) assert res_len == count, 'setup_component failed, expected {} got {}: {}' \ .format(count, res_len, res) + + +def init_recorder_component(hass, add_config=None, db_ready_callback=None): + """Initialize the recorder.""" + config = dict(add_config) if add_config else {} + config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB + + saved_recorder = recorder.Recorder + + class Recorder2(saved_recorder): + """Recorder with a callback after db_ready.""" + + def _setup_connection(self): + """Setup the connection and run the callback.""" + super(Recorder2, self)._setup_connection() + if db_ready_callback: + _LOGGER.debug('db_ready_callback start (db_ready not set,' + 'never use get_instance in the callback)') + db_ready_callback() + _LOGGER.debug('db_ready_callback completed') + + with patch('homeassistant.components.recorder.Recorder', + side_effect=Recorder2): + assert setup_component(hass, recorder.DOMAIN, + {recorder.DOMAIN: config}) + assert recorder.DOMAIN in hass.config.components + recorder.get_instance().block_till_db_ready() + _LOGGER.info("In-memory recorder successfully started") diff --git a/tests/components/light/test_demo.py b/tests/components/light/test_demo.py index aa8c8d9f1e8..f8b46579187 100644 --- a/tests/components/light/test_demo.py +++ b/tests/components/light/test_demo.py @@ -1,17 +1,20 @@ """The tests for the demo light component.""" # pylint: disable=protected-access +import asyncio import unittest -from homeassistant.bootstrap import setup_component +from homeassistant.core import State, CoreState +from homeassistant.bootstrap import setup_component, async_setup_component import homeassistant.components.light as light +from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE from tests.common import get_test_home_assistant ENTITY_LIGHT = 'light.bed_light' -class TestDemoClimate(unittest.TestCase): - """Test the demo climate hvac.""" +class TestDemoLight(unittest.TestCase): + """Test the demo light.""" # pylint: disable=invalid-name def setUp(self): @@ -60,3 +63,36 @@ class TestDemoClimate(unittest.TestCase): light.turn_off(self.hass, ENTITY_LIGHT) self.hass.block_till_done() self.assertFalse(light.is_on(self.hass, ENTITY_LIGHT)) + + +@asyncio.coroutine +def test_restore_state(hass): + """Test state gets restored.""" + hass.config.components.add('recorder') + hass.state = CoreState.starting + hass.data[DATA_RESTORE_CACHE] = { + 'light.bed_light': State('light.bed_light', 'on', { + 'brightness': 'value-brightness', + 'color_temp': 'value-color_temp', + 'rgb_color': 'value-rgb_color', + 'xy_color': 'value-xy_color', + 'white_value': 'value-white_value', + 'effect': 'value-effect', + }), + } + + yield from async_setup_component(hass, 'light', { + 'light': { + 'platform': 'demo', + }}) + + state = hass.states.get('light.bed_light') + assert state is not None + assert state.entity_id == 'light.bed_light' + assert state.state == 'on' + assert state.attributes.get('brightness') == 'value-brightness' + assert state.attributes.get('color_temp') == 'value-color_temp' + assert state.attributes.get('rgb_color') == 'value-rgb_color' + assert state.attributes.get('xy_color') == 'value-xy_color' + assert state.attributes.get('white_value') == 'value-white_value' + assert state.attributes.get('effect') == 'value-effect' diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 0bfa3a20997..fa38a9d3784 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -11,8 +11,7 @@ from sqlalchemy import create_engine from homeassistant.core import callback from homeassistant.const import MATCH_ALL from homeassistant.components import recorder -from homeassistant.bootstrap import setup_component -from tests.common import get_test_home_assistant +from tests.common import get_test_home_assistant, init_recorder_component from tests.components.recorder import models_original @@ -22,18 +21,15 @@ class BaseTestRecorder(unittest.TestCase): def setUp(self): # pylint: disable=invalid-name """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - db_uri = 'sqlite://' # In memory DB - setup_component(self.hass, recorder.DOMAIN, { - recorder.DOMAIN: {recorder.CONF_DB_URL: db_uri}}) + init_recorder_component(self.hass) self.hass.start() - recorder._verify_instance() - recorder._INSTANCE.block_till_done() + recorder.get_instance().block_till_done() def tearDown(self): # pylint: disable=invalid-name """Stop everything that was started.""" - recorder._INSTANCE.shutdown(None) self.hass.stop() - assert recorder._INSTANCE is None + with self.assertRaises(RuntimeError): + recorder.get_instance() def _add_test_states(self): """Add multiple states to the db for testing.""" @@ -228,7 +224,7 @@ class TestMigrateRecorder(BaseTestRecorder): @patch('sqlalchemy.create_engine', new=create_engine_test) @patch('homeassistant.components.recorder.Recorder._migrate_schema') - def setUp(self, migrate): # pylint: disable=invalid-name + def setUp(self, migrate): # pylint: disable=invalid-name,arguments-differ """Setup things to be run when tests are started. create_engine is patched to create a db that starts with the old @@ -261,16 +257,12 @@ def hass_recorder(): """HASS fixture with in-memory recorder.""" hass = get_test_home_assistant() - def setup_recorder(config={}): + def setup_recorder(config=None): """Setup with params.""" - db_uri = 'sqlite://' # In memory DB - conf = {recorder.CONF_DB_URL: db_uri} - conf.update(config) - assert setup_component(hass, recorder.DOMAIN, {recorder.DOMAIN: conf}) + init_recorder_component(hass, config) hass.start() hass.block_till_done() - recorder._verify_instance() - recorder._INSTANCE.block_till_done() + recorder.get_instance().block_till_done() return hass yield setup_recorder @@ -352,12 +344,12 @@ def test_recorder_errors_exceptions(hass_recorder): \ # Verify the instance fails before setup with pytest.raises(RuntimeError): - recorder._verify_instance() + recorder.get_instance() # Setup the recorder hass_recorder() - recorder._verify_instance() + recorder.get_instance() # Verify session scope raises (and prints) an exception with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \ diff --git a/tests/components/sensor/test_history_stats.py b/tests/components/sensor/test_history_stats.py index d7d8b516525..d4f1cbcbe9a 100644 --- a/tests/components/sensor/test_history_stats.py +++ b/tests/components/sensor/test_history_stats.py @@ -1,16 +1,17 @@ """The test for the History Statistics sensor platform.""" # pylint: disable=protected-access -import unittest from datetime import timedelta +import unittest from unittest.mock import patch -import homeassistant.components.recorder as recorder -import homeassistant.core as ha -import homeassistant.util.dt as dt_util from homeassistant.bootstrap import setup_component +import homeassistant.components.recorder as recorder from homeassistant.components.sensor.history_stats import HistoryStatsSensor +import homeassistant.core as ha from homeassistant.helpers.template import Template -from tests.common import get_test_home_assistant +import homeassistant.util.dt as dt_util + +from tests.common import init_recorder_component, get_test_home_assistant class TestHistoryStatsSensor(unittest.TestCase): @@ -204,12 +205,8 @@ class TestHistoryStatsSensor(unittest.TestCase): def init_recorder(self): """Initialize the recorder.""" - db_uri = 'sqlite://' - with patch('homeassistant.core.Config.path', return_value=db_uri): - setup_component(self.hass, recorder.DOMAIN, { - "recorder": { - "db_url": db_uri}}) + init_recorder_component(self.hass) self.hass.start() - recorder._INSTANCE.block_till_db_ready() + recorder.get_instance().block_till_db_ready() self.hass.block_till_done() - recorder._INSTANCE.block_till_done() + recorder.get_instance().block_till_done() diff --git a/tests/components/test_history.py b/tests/components/test_history.py index a79f56b0829..65870d1450f 100644 --- a/tests/components/test_history.py +++ b/tests/components/test_history.py @@ -1,5 +1,5 @@ """The tests the History component.""" -# pylint: disable=protected-access +# pylint: disable=protected-access,invalid-name from datetime import timedelta import unittest from unittest.mock import patch, sentinel @@ -10,68 +10,47 @@ import homeassistant.util.dt as dt_util from homeassistant.components import history, recorder from tests.common import ( - mock_http_component, mock_state_change_event, get_test_home_assistant) + init_recorder_component, mock_http_component, mock_state_change_event, + get_test_home_assistant) class TestComponentHistory(unittest.TestCase): """Test History component.""" - # pylint: disable=invalid-name - def setUp(self): + def setUp(self): # pylint: disable=invalid-name """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() - # pylint: disable=invalid-name - def tearDown(self): + def tearDown(self): # pylint: disable=invalid-name """Stop everything that was started.""" self.hass.stop() def init_recorder(self): """Initialize the recorder.""" - db_uri = 'sqlite://' - with patch('homeassistant.core.Config.path', return_value=db_uri): - setup_component(self.hass, recorder.DOMAIN, { - "recorder": { - "db_url": db_uri}}) + init_recorder_component(self.hass) self.hass.start() - recorder._INSTANCE.block_till_db_ready() + recorder.get_instance().block_till_db_ready() self.wait_recording_done() def wait_recording_done(self): """Block till recording is done.""" self.hass.block_till_done() - recorder._INSTANCE.block_till_done() + recorder.get_instance().block_till_done() def test_setup(self): """Test setup method of history.""" mock_http_component(self.hass) config = history.CONFIG_SCHEMA({ - ha.DOMAIN: {}, - history.DOMAIN: {history.CONF_INCLUDE: { + # ha.DOMAIN: {}, + history.DOMAIN: { + history.CONF_INCLUDE: { history.CONF_DOMAINS: ['media_player'], history.CONF_ENTITIES: ['thermostat.test']}, history.CONF_EXCLUDE: { history.CONF_DOMAINS: ['thermostat'], history.CONF_ENTITIES: ['media_player.test']}}}) - self.assertTrue(setup_component(self.hass, history.DOMAIN, config)) - - def test_last_5_states(self): - """Test retrieving the last 5 states.""" self.init_recorder() - states = [] - - entity_id = 'test.last_5_states' - - for i in range(7): - self.hass.states.set(entity_id, "State {}".format(i)) - - self.wait_recording_done() - - if i > 1: - states.append(self.hass.states.get(entity_id)) - - self.assertEqual( - list(reversed(states)), history.last_5_states(entity_id)) + self.assertTrue(setup_component(self.hass, history.DOMAIN, config)) def test_get_states(self): """Test getting states at a specific point in time.""" @@ -121,6 +100,7 @@ class TestComponentHistory(unittest.TestCase): entity_id = 'media_player.test' def set_state(state): + """Set the state.""" self.hass.states.set(entity_id, state) self.wait_recording_done() return self.hass.states.get(entity_id) @@ -311,7 +291,8 @@ class TestComponentHistory(unittest.TestCase): config = history.CONFIG_SCHEMA({ ha.DOMAIN: {}, - history.DOMAIN: {history.CONF_INCLUDE: { + history.DOMAIN: { + history.CONF_INCLUDE: { history.CONF_DOMAINS: ['media_player']}, history.CONF_EXCLUDE: { history.CONF_DOMAINS: ['media_player']}}}) @@ -332,7 +313,8 @@ class TestComponentHistory(unittest.TestCase): config = history.CONFIG_SCHEMA({ ha.DOMAIN: {}, - history.DOMAIN: {history.CONF_INCLUDE: { + history.DOMAIN: { + history.CONF_INCLUDE: { history.CONF_ENTITIES: ['media_player.test']}, history.CONF_EXCLUDE: { history.CONF_ENTITIES: ['media_player.test']}}}) @@ -351,7 +333,8 @@ class TestComponentHistory(unittest.TestCase): config = history.CONFIG_SCHEMA({ ha.DOMAIN: {}, - history.DOMAIN: {history.CONF_INCLUDE: { + history.DOMAIN: { + history.CONF_INCLUDE: { history.CONF_DOMAINS: ['media_player'], history.CONF_ENTITIES: ['thermostat.test']}, history.CONF_EXCLUDE: { @@ -359,7 +342,8 @@ class TestComponentHistory(unittest.TestCase): history.CONF_ENTITIES: ['media_player.test']}}}) self.check_significant_states(zero, four, states, config) - def check_significant_states(self, zero, four, states, config): + def check_significant_states(self, zero, four, states, config): \ + # pylint: disable=no-self-use """Check if significant states are retrieved.""" filters = history.Filters() exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE) @@ -390,6 +374,7 @@ class TestComponentHistory(unittest.TestCase): script_c = 'script.can_cancel_this_one' def set_state(entity_id, state, **kwargs): + """Set the state.""" self.hass.states.set(entity_id, state, **kwargs) self.wait_recording_done() return self.hass.states.get(entity_id) diff --git a/tests/components/test_input_boolean.py b/tests/components/test_input_boolean.py index 1e261ccbcc8..c22c431ed03 100644 --- a/tests/components/test_input_boolean.py +++ b/tests/components/test_input_boolean.py @@ -1,15 +1,18 @@ """The tests for the input_boolean component.""" # pylint: disable=protected-access +import asyncio import unittest import logging from tests.common import get_test_home_assistant -from homeassistant.bootstrap import setup_component +from homeassistant.core import CoreState, State +from homeassistant.bootstrap import setup_component, async_setup_component from homeassistant.components.input_boolean import ( DOMAIN, is_on, toggle, turn_off, turn_on) from homeassistant.const import ( STATE_ON, STATE_OFF, ATTR_ICON, ATTR_FRIENDLY_NAME) +from homeassistant.helpers.restore_state import DATA_RESTORE_CACHE _LOGGER = logging.getLogger(__name__) @@ -103,3 +106,30 @@ class TestInputBoolean(unittest.TestCase): self.assertEqual('Hello World', state_2.attributes.get(ATTR_FRIENDLY_NAME)) self.assertEqual('mdi:work', state_2.attributes.get(ATTR_ICON)) + + +@asyncio.coroutine +def test_restore_state(hass): + """Ensure states are restored on startup.""" + hass.data[DATA_RESTORE_CACHE] = { + 'input_boolean.b1': State('input_boolean.b1', 'on'), + 'input_boolean.b2': State('input_boolean.b2', 'off'), + 'input_boolean.b3': State('input_boolean.b3', 'on'), + } + + hass.state = CoreState.starting + hass.config.components.add('recorder') + + yield from async_setup_component(hass, DOMAIN, { + DOMAIN: { + 'b1': None, + 'b2': None, + }}) + + state = hass.states.get('input_boolean.b1') + assert state + assert state.state == 'on' + + state = hass.states.get('input_boolean.b2') + assert state + assert state.state == 'off' diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index b6583ba3536..69497ef8388 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -1,5 +1,6 @@ """The tests for the logbook component.""" -# pylint: disable=protected-access +# pylint: disable=protected-access,invalid-name +import logging from datetime import timedelta import unittest from unittest.mock import patch @@ -13,7 +14,11 @@ import homeassistant.util.dt as dt_util from homeassistant.components import logbook from homeassistant.bootstrap import setup_component -from tests.common import mock_http_component, get_test_home_assistant +from tests.common import ( + mock_http_component, init_recorder_component, get_test_home_assistant) + + +_LOGGER = logging.getLogger(__name__) class TestComponentLogbook(unittest.TestCase): @@ -24,12 +29,14 @@ class TestComponentLogbook(unittest.TestCase): def setUp(self): """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() + init_recorder_component(self.hass) # Force an in memory DB mock_http_component(self.hass) self.hass.config.components |= set(['frontend', 'recorder', 'api']) with patch('homeassistant.components.logbook.' 'register_built_in_panel'): assert setup_component(self.hass, logbook.DOMAIN, self.EMPTY_CONFIG) + self.hass.start() def tearDown(self): """Stop everything that was started.""" @@ -41,6 +48,7 @@ class TestComponentLogbook(unittest.TestCase): @ha.callback def event_listener(event): + """Append on event.""" calls.append(event) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) @@ -72,6 +80,7 @@ class TestComponentLogbook(unittest.TestCase): @ha.callback def event_listener(event): + """Append on event.""" calls.append(event) self.hass.bus.listen(logbook.EVENT_LOGBOOK_ENTRY, event_listener) @@ -242,17 +251,17 @@ class TestComponentLogbook(unittest.TestCase): entity_id2 = 'sensor.blu' eventA = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, { - logbook.ATTR_NAME: name, - logbook.ATTR_MESSAGE: message, - logbook.ATTR_DOMAIN: domain, - logbook.ATTR_ENTITY_ID: entity_id, - }) + logbook.ATTR_NAME: name, + logbook.ATTR_MESSAGE: message, + logbook.ATTR_DOMAIN: domain, + logbook.ATTR_ENTITY_ID: entity_id, + }) eventB = ha.Event(logbook.EVENT_LOGBOOK_ENTRY, { - logbook.ATTR_NAME: name, - logbook.ATTR_MESSAGE: message, - logbook.ATTR_DOMAIN: domain, - logbook.ATTR_ENTITY_ID: entity_id2, - }) + logbook.ATTR_NAME: name, + logbook.ATTR_MESSAGE: message, + logbook.ATTR_DOMAIN: domain, + logbook.ATTR_ENTITY_ID: entity_id2, + }) config = logbook.CONFIG_SCHEMA({ ha.DOMAIN: {}, @@ -532,7 +541,8 @@ class TestComponentLogbook(unittest.TestCase): def create_state_changed_event(self, event_time_fired, entity_id, state, attributes=None, last_changed=None, - last_updated=None): + last_updated=None): \ + # pylint: disable=no-self-use """Create state changed event.""" # Logbook only cares about state change events that # contain an old state but will not actually act on it. diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py new file mode 100644 index 00000000000..02e374c8576 --- /dev/null +++ b/tests/helpers/test_restore_state.py @@ -0,0 +1,42 @@ +"""The tests for the Restore component.""" +import asyncio +from unittest.mock import patch, MagicMock + +from homeassistant.const import EVENT_HOMEASSISTANT_START +from homeassistant.core import CoreState, State +import homeassistant.util.dt as dt_util + +from homeassistant.helpers.restore_state import ( + async_get_last_state, DATA_RESTORE_CACHE) + + +@asyncio.coroutine +def test_caching_data(hass): + """Test that we cache data.""" + hass.config.components.add('recorder') + hass.state = CoreState.starting + + states = [ + State('input_boolean.b0', 'on'), + State('input_boolean.b1', 'on'), + State('input_boolean.b2', 'on'), + ] + + with patch('homeassistant.helpers.restore_state.last_recorder_run', + return_value=MagicMock(end=dt_util.utcnow())), \ + patch('homeassistant.helpers.restore_state.get_states', + return_value=states): + state = yield from async_get_last_state(hass, 'input_boolean.b1') + + assert DATA_RESTORE_CACHE in hass.data + assert hass.data[DATA_RESTORE_CACHE] == {st.entity_id: st for st in states} + + assert state is not None + assert state.entity_id == 'input_boolean.b1' + assert state.state == 'on' + + hass.bus.async_fire(EVENT_HOMEASSISTANT_START) + + yield from hass.async_block_till_done() + + assert DATA_RESTORE_CACHE not in hass.data