diff --git a/homeassistant/components/history.py b/homeassistant/components/history.py index 254115c55b1..5c68f767cd2 100644 --- a/homeassistant/components/history.py +++ b/homeassistant/components/history.py @@ -20,6 +20,7 @@ from homeassistant.components import recorder, script from homeassistant.components.frontend import register_built_in_panel from homeassistant.components.http import HomeAssistantView from homeassistant.const import ATTR_HIDDEN +from homeassistant.components.recorder.util import session_scope, execute _LOGGER = logging.getLogger(__name__) @@ -34,19 +35,20 @@ SIGNIFICANT_DOMAINS = ('thermostat', 'climate') IGNORE_DOMAINS = ('zone', 'scene',) -def last_recorder_run(): +def last_recorder_run(hass): """Retireve the last closed recorder run from the DB.""" - recorder.get_instance() - rec_runs = recorder.get_model('RecorderRuns') - with recorder.session_scope() as session: - res = recorder.query(rec_runs).order_by(rec_runs.end.desc()).first() + from homeassistant.components.recorder.models import RecorderRuns + + with session_scope(hass=hass) as session: + res = (session.query(RecorderRuns) + .order_by(RecorderRuns.end.desc()).first()) if res is None: return None session.expunge(res) return res -def get_significant_states(start_time, end_time=None, entity_id=None, +def get_significant_states(hass, start_time, end_time=None, entity_id=None, filters=None): """ Return states changes during UTC period start_time - end_time. @@ -55,50 +57,60 @@ def get_significant_states(start_time, end_time=None, entity_id=None, as well as all states from certain domains (for instance thermostat so that we get current temperature in our graphs). """ + from homeassistant.components.recorder.models import States + entity_ids = (entity_id.lower(), ) if entity_id is not None else None - states = recorder.get_model('States') - query = recorder.query(states).filter( - (states.domain.in_(SIGNIFICANT_DOMAINS) | - (states.last_changed == states.last_updated)) & - (states.last_updated > start_time)) - if filters: - query = filters.apply(query, entity_ids) - if end_time is not None: - query = query.filter(states.last_updated < end_time) + with session_scope(hass=hass) as session: + query = session.query(States).filter( + (States.domain.in_(SIGNIFICANT_DOMAINS) | + (States.last_changed == States.last_updated)) & + (States.last_updated > start_time)) - states = ( - state for state in recorder.execute( - query.order_by(states.entity_id, states.last_updated)) - if (_is_significant(state) and - not state.attributes.get(ATTR_HIDDEN, False))) + if filters: + query = filters.apply(query, entity_ids) - return states_to_json(states, start_time, entity_id, filters) + if end_time is not None: + query = query.filter(States.last_updated < end_time) + + states = ( + state for state in execute( + query.order_by(States.entity_id, States.last_updated)) + if (_is_significant(state) and + not state.attributes.get(ATTR_HIDDEN, False))) + + return states_to_json(hass, states, start_time, entity_id, filters) -def state_changes_during_period(start_time, end_time=None, entity_id=None): +def state_changes_during_period(hass, start_time, end_time=None, + entity_id=None): """Return states changes during UTC period start_time - end_time.""" - states = recorder.get_model('States') - query = recorder.query(states).filter( - (states.last_changed == states.last_updated) & - (states.last_changed > start_time)) + from homeassistant.components.recorder.models import States - if end_time is not None: - query = query.filter(states.last_updated < end_time) + with session_scope(hass=hass) as session: + query = session.query(States).filter( + (States.last_changed == States.last_updated) & + (States.last_changed > start_time)) - if entity_id is not None: - query = query.filter_by(entity_id=entity_id.lower()) + if end_time is not None: + query = query.filter(States.last_updated < end_time) - states = recorder.execute( - query.order_by(states.entity_id, states.last_updated)) + if entity_id is not None: + query = query.filter_by(entity_id=entity_id.lower()) - return states_to_json(states, start_time, entity_id) + states = execute( + query.order_by(States.entity_id, States.last_updated)) + + return states_to_json(hass, states, start_time, entity_id) -def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None): +def get_states(hass, utc_point_in_time, entity_ids=None, run=None, + filters=None): """Return the states at a specific point in time.""" + from homeassistant.components.recorder.models import States + if run is None: - run = recorder.run_information(utc_point_in_time) + run = recorder.run_information(hass, utc_point_in_time) # History did not run before utc_point_in_time if run is None: @@ -106,29 +118,29 @@ def get_states(utc_point_in_time, entity_ids=None, run=None, filters=None): from sqlalchemy import and_, func - states = recorder.get_model('States') - most_recent_state_ids = recorder.query( - func.max(states.state_id).label('max_state_id') - ).filter( - (states.created >= run.start) & - (states.created < utc_point_in_time) & - (~states.domain.in_(IGNORE_DOMAINS))) - if filters: - most_recent_state_ids = filters.apply(most_recent_state_ids, - entity_ids) + with session_scope(hass=hass) as session: + most_recent_state_ids = session.query( + func.max(States.state_id).label('max_state_id') + ).filter( + (States.created >= run.start) & + (States.created < utc_point_in_time) & + (~States.domain.in_(IGNORE_DOMAINS))) - most_recent_state_ids = most_recent_state_ids.group_by( - states.entity_id).subquery() + if filters: + most_recent_state_ids = filters.apply(most_recent_state_ids, + entity_ids) - query = recorder.query(states).join(most_recent_state_ids, and_( - states.state_id == most_recent_state_ids.c.max_state_id)) + most_recent_state_ids = most_recent_state_ids.group_by( + States.entity_id).subquery() - for state in recorder.execute(query): - if not state.attributes.get(ATTR_HIDDEN, False): - yield state + query = session.query(States).join(most_recent_state_ids, and_( + States.state_id == most_recent_state_ids.c.max_state_id)) + + return [state for state in execute(query) + if not state.attributes.get(ATTR_HIDDEN, False)] -def states_to_json(states, start_time, entity_id, filters=None): +def states_to_json(hass, states, start_time, entity_id, filters=None): """Convert SQL results into JSON friendly data structure. This takes our state list and turns it into a JSON friendly data @@ -143,7 +155,7 @@ def states_to_json(states, start_time, entity_id, filters=None): entity_ids = [entity_id] if entity_id is not None else None # Get the states at the start time - for state in get_states(start_time, entity_ids, filters=filters): + for state in get_states(hass, start_time, entity_ids, filters=filters): state.last_changed = start_time state.last_updated = start_time result[state.entity_id].append(state) @@ -154,9 +166,9 @@ def states_to_json(states, start_time, entity_id, filters=None): return result -def get_state(utc_point_in_time, entity_id, run=None): +def get_state(hass, utc_point_in_time, entity_id, run=None): """Return a state at a specific point in time.""" - states = list(get_states(utc_point_in_time, (entity_id,), run)) + states = list(get_states(hass, utc_point_in_time, (entity_id,), run)) return states[0] if states else None @@ -173,7 +185,6 @@ def setup(hass, config): filters.included_entities = include[CONF_ENTITIES] filters.included_domains = include[CONF_DOMAINS] - recorder.get_instance() hass.http.register_view(HistoryPeriodView(filters)) register_built_in_panel(hass, 'history', 'History', 'mdi:poll-box') @@ -223,8 +234,8 @@ class HistoryPeriodView(HomeAssistantView): entity_id = request.GET.get('filter_entity_id') result = yield from request.app['hass'].loop.run_in_executor( - None, get_significant_states, start_time, end_time, entity_id, - self.filters) + None, get_significant_states, request.app['hass'], start_time, + end_time, entity_id, self.filters) result = result.values() if _LOGGER.isEnabledFor(logging.DEBUG): elapsed = time.perf_counter() - timer_start @@ -254,41 +265,42 @@ class Filters(object): * if include and exclude is defined - select the entities specified in the include and filter out the ones from the exclude list. """ - states = recorder.get_model('States') + from homeassistant.components.recorder.models import States + # specific entities requested - do not in/exclude anything if entity_ids is not None: - return query.filter(states.entity_id.in_(entity_ids)) - query = query.filter(~states.domain.in_(IGNORE_DOMAINS)) + return query.filter(States.entity_id.in_(entity_ids)) + query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) filter_query = None # filter if only excluded domain is configured if self.excluded_domains and not self.included_domains: - filter_query = ~states.domain.in_(self.excluded_domains) + filter_query = ~States.domain.in_(self.excluded_domains) if self.included_entities: - filter_query &= states.entity_id.in_(self.included_entities) + filter_query &= States.entity_id.in_(self.included_entities) # filter if only included domain is configured elif not self.excluded_domains and self.included_domains: - filter_query = states.domain.in_(self.included_domains) + filter_query = States.domain.in_(self.included_domains) if self.included_entities: - filter_query |= states.entity_id.in_(self.included_entities) + filter_query |= States.entity_id.in_(self.included_entities) # filter if included and excluded domain is configured elif self.excluded_domains and self.included_domains: - filter_query = ~states.domain.in_(self.excluded_domains) + filter_query = ~States.domain.in_(self.excluded_domains) if self.included_entities: - filter_query &= (states.domain.in_(self.included_domains) | - states.entity_id.in_(self.included_entities)) + filter_query &= (States.domain.in_(self.included_domains) | + States.entity_id.in_(self.included_entities)) else: - filter_query &= (states.domain.in_(self.included_domains) & ~ - states.domain.in_(self.excluded_domains)) + filter_query &= (States.domain.in_(self.included_domains) & ~ + States.domain.in_(self.excluded_domains)) # no domain filter just included entities elif not self.excluded_domains and not self.included_domains and \ self.included_entities: - filter_query = states.entity_id.in_(self.included_entities) + filter_query = States.entity_id.in_(self.included_entities) if filter_query is not None: query = query.filter(filter_query) # finally apply excluded entities filter if configured if self.excluded_entities: - query = query.filter(~states.entity_id.in_(self.excluded_entities)) + query = query.filter(~States.entity_id.in_(self.excluded_entities)) return query diff --git a/homeassistant/components/logbook.py b/homeassistant/components/logbook.py index 30d52303099..92f99887867 100644 --- a/homeassistant/components/logbook.py +++ b/homeassistant/components/logbook.py @@ -14,7 +14,7 @@ import voluptuous as vol from homeassistant.core import callback import homeassistant.helpers.config_validation as cv import homeassistant.util.dt as dt_util -from homeassistant.components import recorder, sun +from homeassistant.components import sun from homeassistant.components.frontend import register_built_in_panel from homeassistant.components.http import HomeAssistantView from homeassistant.const import (EVENT_HOMEASSISTANT_START, @@ -98,7 +98,7 @@ def setup(hass, config): message = message.async_render() async_log_entry(hass, name, message, domain, entity_id) - hass.http.register_view(LogbookView(config)) + hass.http.register_view(LogbookView(config.get(DOMAIN, {}))) register_built_in_panel(hass, 'logbook', 'Logbook', 'mdi:format-list-bulleted-type') @@ -132,20 +132,11 @@ class LogbookView(HomeAssistantView): start_day = dt_util.as_utc(datetime) end_day = start_day + timedelta(days=1) + hass = request.app['hass'] - def get_results(): - """Query DB for results.""" - events = recorder.get_model('Events') - query = recorder.query('Events').order_by( - events.time_fired).filter( - (events.time_fired > start_day) & - (events.time_fired < end_day)) - events = recorder.execute(query) - return _exclude_events(events, self.config) - - events = yield from request.app['hass'].loop.run_in_executor( - None, get_results) - + events = yield from hass.loop.run_in_executor( + None, _get_events, hass, start_day, end_day) + events = _exclude_events(events, self.config) return self.json(humanify(events)) @@ -282,17 +273,31 @@ def humanify(events): entity_id) +def _get_events(hass, start_day, end_day): + """Get events for a period of time.""" + from homeassistant.components.recorder.models import Events + from homeassistant.components.recorder.util import ( + execute, session_scope) + + with session_scope(hass=hass) as session: + query = session.query(Events).order_by( + Events.time_fired).filter( + (Events.time_fired > start_day) & + (Events.time_fired < end_day)) + return execute(query) + + def _exclude_events(events, config): """Get lists of excluded entities and platforms.""" excluded_entities = [] excluded_domains = [] included_entities = [] included_domains = [] - exclude = config[DOMAIN].get(CONF_EXCLUDE) + exclude = config.get(CONF_EXCLUDE) if exclude: excluded_entities = exclude[CONF_ENTITIES] excluded_domains = exclude[CONF_DOMAINS] - include = config[DOMAIN].get(CONF_INCLUDE) + include = config.get(CONF_INCLUDE) if include: included_entities = include[CONF_ENTITIES] included_domains = include[CONF_DOMAINS] diff --git a/homeassistant/components/recorder/__init__.py b/homeassistant/components/recorder/__init__.py index 0f8d7b48fe2..c60b95d1cae 100644 --- a/homeassistant/components/recorder/__init__.py +++ b/homeassistant/components/recorder/__init__.py @@ -8,27 +8,31 @@ For more details about this component, please refer to the documentation at https://home-assistant.io/components/recorder/ """ import asyncio +import concurrent.futures import logging import queue import threading import time from datetime import timedelta, datetime -from typing import Any, Union, Optional, List, Dict -from contextlib import contextmanager +from typing import Optional, Dict import voluptuous as vol -from homeassistant.core import HomeAssistant, callback, split_entity_id +from homeassistant.core import ( + HomeAssistant, callback, split_entity_id, CoreState) from homeassistant.const import ( ATTR_ENTITY_ID, CONF_ENTITIES, CONF_EXCLUDE, CONF_DOMAINS, - CONF_INCLUDE, EVENT_HOMEASSISTANT_STOP, + CONF_INCLUDE, EVENT_HOMEASSISTANT_STOP, EVENT_HOMEASSISTANT_START, EVENT_STATE_CHANGED, EVENT_TIME_CHANGED, MATCH_ALL) -from homeassistant.exceptions import HomeAssistantError import homeassistant.helpers.config_validation as cv from homeassistant.helpers.event import async_track_time_interval -from homeassistant.helpers.typing import ConfigType, QueryType +from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util +from . import purge, migration +from .const import DATA_INSTANCE +from .util import session_scope + DOMAIN = 'recorder' REQUIREMENTS = ['sqlalchemy==1.1.5'] @@ -39,9 +43,7 @@ DEFAULT_DB_FILE = 'home-assistant_v2.db' CONF_DB_URL = 'db_url' CONF_PURGE_DAYS = 'purge_days' -RETRIES = 3 CONNECT_RETRY_WAIT = 10 -QUERY_RETRY_WAIT = 0.1 ERROR_QUERY = "Error during query: %s" FILTER_SCHEMA = vol.Schema({ @@ -65,88 +67,32 @@ CONFIG_SCHEMA = vol.Schema({ }) }, extra=vol.ALLOW_EXTRA) -_INSTANCE = None # type: Any _LOGGER = logging.getLogger(__name__) -@contextmanager -def session_scope(): - """Provide a transactional scope around a series of operations.""" - session = _INSTANCE.get_session() - try: - yield session - session.commit() - except Exception as err: # pylint: disable=broad-except - _LOGGER.error(ERROR_QUERY, err) - session.rollback() - raise - finally: - session.close() - - -@asyncio.coroutine -def async_get_instance(): - """Throw error if recorder not initialized.""" - if _INSTANCE is None: - raise RuntimeError("Recorder not initialized.") - - yield from _INSTANCE.async_db_ready.wait() - - return _INSTANCE - - -def get_instance(): - """Throw error if recorder not initialized.""" - if _INSTANCE is None: - raise RuntimeError("Recorder not initialized.") - - ident = _INSTANCE.hass.loop.__dict__.get("_thread_ident") - if ident is not None and ident == threading.get_ident(): - raise RuntimeError('Cannot be called from within the event loop') - - _wait(_INSTANCE.db_ready, "Database not ready") - - return _INSTANCE - - -# pylint: disable=invalid-sequence-index -def execute(qry: QueryType) -> List[Any]: - """Query the database and convert the objects to HA native form. - - This method also retries a few times in the case of stale connections. +def wait_connection_ready(hass): """ - get_instance() - from sqlalchemy.exc import SQLAlchemyError - with session_scope() as session: - for _ in range(0, RETRIES): - try: - return [ - row for row in - (row.to_native() for row in qry) - if row is not None] - except SQLAlchemyError as err: - _LOGGER.error(ERROR_QUERY, err) - session.rollback() - time.sleep(QUERY_RETRY_WAIT) - return [] + Wait till the connection is ready. + + Returns a coroutine object. + """ + return hass.data[DATA_INSTANCE].async_db_ready.wait() -def run_information(point_in_time: Optional[datetime]=None): +def run_information(hass, point_in_time: Optional[datetime]=None): """Return information about current run. There is also the run that covers point_in_time. """ - ins = get_instance() + from . import models + ins = hass.data[DATA_INSTANCE] - recorder_runs = get_model('RecorderRuns') + recorder_runs = models.RecorderRuns if point_in_time is None or point_in_time > ins.recording_start: - return recorder_runs( - end=None, - start=ins.recording_start, - closed_incorrect=False) + return ins.run_info - with session_scope() as session: - res = query(recorder_runs).filter( + with session_scope(hass=hass) as session: + res = session.query(recorder_runs).filter( (recorder_runs.start < point_in_time) & (recorder_runs.end > point_in_time)).first() if res: @@ -154,88 +100,67 @@ def run_information(point_in_time: Optional[datetime]=None): return res -def setup(hass: HomeAssistant, config: ConfigType) -> bool: +@asyncio.coroutine +def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Setup the recorder.""" - global _INSTANCE # pylint: disable=global-statement + conf = config.get(DOMAIN, {}) + purge_days = conf.get(CONF_PURGE_DAYS) - if _INSTANCE is not None: - _LOGGER.error("Only a single instance allowed") - return False - - purge_days = config.get(DOMAIN, {}).get(CONF_PURGE_DAYS) - - db_url = config.get(DOMAIN, {}).get(CONF_DB_URL, None) + db_url = conf.get(CONF_DB_URL, None) if not db_url: db_url = DEFAULT_URL.format( hass_config_path=hass.config.path(DEFAULT_DB_FILE)) - include = config.get(DOMAIN, {}).get(CONF_INCLUDE, {}) - exclude = config.get(DOMAIN, {}).get(CONF_EXCLUDE, {}) - _INSTANCE = Recorder(hass, purge_days=purge_days, uri=db_url, - include=include, exclude=exclude) - _INSTANCE.start() + include = conf.get(CONF_INCLUDE, {}) + exclude = conf.get(CONF_EXCLUDE, {}) + hass.data[DATA_INSTANCE] = Recorder( + hass, purge_days=purge_days, uri=db_url, include=include, + exclude=exclude) + hass.data[DATA_INSTANCE].async_initialize() + hass.data[DATA_INSTANCE].start() return True -def query(model_name: Union[str, Any], session=None, *args) -> QueryType: - """Helper to return a query handle.""" - if session is None: - session = get_instance().get_session() - - if isinstance(model_name, str): - return session.query(get_model(model_name), *args) - return session.query(model_name, *args) - - -def get_model(model_name: str) -> Any: - """Get a model class.""" - from homeassistant.components.recorder import models - try: - return getattr(models, model_name) - except AttributeError: - _LOGGER.error("Invalid model name %s", model_name) - return None - - class Recorder(threading.Thread): """A threaded recorder class.""" def __init__(self, hass: HomeAssistant, purge_days: int, uri: str, include: Dict, exclude: Dict) -> None: """Initialize the recorder.""" - threading.Thread.__init__(self) + threading.Thread.__init__(self, name='Recorder') self.hass = hass self.purge_days = purge_days self.queue = queue.Queue() # type: Any self.recording_start = dt_util.utcnow() self.db_url = uri - self.db_ready = threading.Event() self.async_db_ready = asyncio.Event(loop=hass.loop) self.engine = None # type: Any - self._run = None # type: Any + self.run_info = None # type: Any self.include_e = include.get(CONF_ENTITIES, []) self.include_d = include.get(CONF_DOMAINS, []) self.exclude = exclude.get(CONF_ENTITIES, []) + \ exclude.get(CONF_DOMAINS, []) - hass.bus.listen_once(EVENT_HOMEASSISTANT_STOP, self.shutdown) - hass.bus.listen(MATCH_ALL, self.event_listener) - self.get_session = None + @callback + def async_initialize(self): + """Initialize the recorder.""" + self.hass.bus.async_listen(MATCH_ALL, self.event_listener) + def run(self): """Start processing events to save.""" - from homeassistant.components.recorder.models import Events, States from sqlalchemy.exc import SQLAlchemyError + from .models import States, Events while True: try: self._setup_connection() + migration.migrate_schema(self) self._setup_run() - self.db_ready.set() self.hass.loop.call_soon_threadsafe(self.async_db_ready.set) break except SQLAlchemyError as err: @@ -243,9 +168,49 @@ class Recorder(threading.Thread): "in %s seconds)", err, CONNECT_RETRY_WAIT) time.sleep(CONNECT_RETRY_WAIT) - if self.purge_days is not None: - async_track_time_interval( - self.hass, self._purge_old_data, timedelta(days=2)) + purge_task = object() + shutdown_task = object() + hass_started = concurrent.futures.Future() + + @callback + def register(): + """Post connection initialize.""" + def shutdown(event): + """Shut down the Recorder.""" + if not hass_started.done(): + hass_started.set_result(shutdown_task) + self.queue.put(None) + self.join() + + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, + shutdown) + + if self.hass.state == CoreState.running: + hass_started.set_result(None) + else: + @callback + def notify_hass_started(event): + """Notify that hass has started.""" + hass_started.set_result(None) + + self.hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, + notify_hass_started) + + if self.purge_days is not None: + @callback + def do_purge(now): + """Event listener for purging data.""" + self.queue.put(purge_task) + + async_track_time_interval(self.hass, do_purge, + timedelta(days=2)) + + self.hass.add_job(register) + result = hass_started.result() + + # If shutdown happened before HASS finished starting + if result is shutdown_task: + return while True: event = self.queue.get() @@ -255,8 +220,10 @@ class Recorder(threading.Thread): self._close_connection() self.queue.task_done() return - - if event.event_type == EVENT_TIME_CHANGED: + elif event is purge_task: + purge.purge_old_data(self, self.purge_days) + continue + elif event.event_type == EVENT_TIME_CHANGED: self.queue.task_done() continue @@ -280,17 +247,14 @@ class Recorder(threading.Thread): self.queue.task_done() continue - with session_scope() as session: + with session_scope(session=self.get_session()) as session: dbevent = Events.from_event(event) - self._commit(session, dbevent) + session.add(dbevent) - if event.event_type != EVENT_STATE_CHANGED: - self.queue.task_done() - continue - - dbstate = States.from_event(event) - dbstate.event_id = dbevent.event_id - self._commit(session, dbstate) + if event.event_type == EVENT_STATE_CHANGED: + dbstate = States.from_event(event) + dbstate.event_id = dbevent.event_id + session.add(dbstate) self.queue.task_done() @@ -299,27 +263,16 @@ class Recorder(threading.Thread): """Listen for new events and put them in the process queue.""" self.queue.put(event) - def shutdown(self, event): - """Tell the recorder to shut down.""" - global _INSTANCE # pylint: disable=global-statement - self.queue.put(None) - self.join() - _INSTANCE = None - def block_till_done(self): """Block till all events processed.""" self.queue.join() - def block_till_db_ready(self): - """Block until the database session is ready.""" - _wait(self.db_ready, "Database not ready") - def _setup_connection(self): """Ensure database is ready to fly.""" - import homeassistant.components.recorder.models as models from sqlalchemy import create_engine from sqlalchemy.orm import scoped_session from sqlalchemy.orm import sessionmaker + from . import models if self.db_url == 'sqlite://' or ':memory:' in self.db_url: from sqlalchemy.pool import StaticPool @@ -334,85 +287,6 @@ class Recorder(threading.Thread): models.Base.metadata.create_all(self.engine) session_factory = sessionmaker(bind=self.engine) self.get_session = scoped_session(session_factory) - self._migrate_schema() - - def _migrate_schema(self): - """Check if the schema needs to be upgraded.""" - from homeassistant.components.recorder.models import SCHEMA_VERSION - schema_changes = get_model('SchemaChanges') - with session_scope() as session: - res = session.query(schema_changes).order_by( - schema_changes.change_id.desc()).first() - current_version = getattr(res, 'schema_version', None) - - if current_version == SCHEMA_VERSION: - return - _LOGGER.debug("Schema version incorrect: %s", current_version) - - if current_version is None: - current_version = self._inspect_schema_version() - _LOGGER.debug("No schema version found. Inspected version: %s", - current_version) - - for version in range(current_version, SCHEMA_VERSION): - new_version = version + 1 - _LOGGER.info("Upgrading recorder db schema to version %s", - new_version) - self._apply_update(new_version) - self._commit(session, - schema_changes(schema_version=new_version)) - _LOGGER.info("Upgraded recorder db schema to version %s", - new_version) - - def _apply_update(self, new_version): - """Perform operations to bring schema up to date.""" - from sqlalchemy import Table - import homeassistant.components.recorder.models as models - - if new_version == 1: - def create_index(table_name, column_name): - """Create an index for the specified table and column.""" - table = Table(table_name, models.Base.metadata) - name = "_".join(("ix", table_name, column_name)) - # Look up the index object that was created from the models - index = next(idx for idx in table.indexes if idx.name == name) - _LOGGER.debug("Creating index for table %s column %s", - table_name, column_name) - index.create(self.engine) - _LOGGER.debug("Index creation done for table %s column %s", - table_name, column_name) - - create_index("events", "time_fired") - else: - raise ValueError("No schema migration defined for version {}" - .format(new_version)) - - def _inspect_schema_version(self): - """Determine the schema version by inspecting the db structure. - - When the schema verison is not present in the db, either db was just - created with the correct schema, or this is a db created before schema - versions were tracked. For now, we'll test if the changes for schema - version 1 are present to make the determination. Eventually this logic - can be removed and we can assume a new db is being created. - """ - from sqlalchemy.engine import reflection - import homeassistant.components.recorder.models as models - inspector = reflection.Inspector.from_engine(self.engine) - indexes = inspector.get_indexes("events") - with session_scope() as session: - for index in indexes: - if index['column_names'] == ["time_fired"]: - # Schema addition from version 1 detected. New DB. - current_version = models.SchemaChanges( - schema_version=models.SCHEMA_VERSION) - self._commit(session, current_version) - return models.SCHEMA_VERSION - - # Version 1 schema changes not found, this db needs to be migrated. - current_version = models.SchemaChanges(schema_version=0) - self._commit(session, current_version) - return current_version.schema_version def _close_connection(self): """Close the connection.""" @@ -422,93 +296,27 @@ class Recorder(threading.Thread): def _setup_run(self): """Log the start of the current run.""" - recorder_runs = get_model('RecorderRuns') - with session_scope() as session: - for run in query( - recorder_runs, session=session).filter_by(end=None): + from .models import RecorderRuns + + with session_scope(session=self.get_session()) as session: + for run in session.query(RecorderRuns).filter_by(end=None): run.closed_incorrect = True run.end = self.recording_start _LOGGER.warning("Ended unfinished session (id=%s from %s)", run.run_id, run.start) session.add(run) - _LOGGER.warning("Found unfinished sessions") - - self._run = recorder_runs( + self.run_info = RecorderRuns( start=self.recording_start, created=dt_util.utcnow() ) - self._commit(session, self._run) + session.add(self.run_info) + session.flush() + session.expunge(self.run_info) def _close_run(self): """Save end time for current run.""" - with session_scope() as session: - self._run.end = dt_util.utcnow() - self._commit(session, self._run) - self._run = None - - def _purge_old_data(self, _=None): - """Purge events and states older than purge_days ago.""" - from homeassistant.components.recorder.models import Events, States - - if not self.purge_days or self.purge_days < 1: - _LOGGER.debug("purge_days set to %s, will not purge any old data.", - self.purge_days) - return - - purge_before = dt_util.utcnow() - timedelta(days=self.purge_days) - - def _purge_states(session): - deleted_rows = session.query(States) \ - .filter((States.created < purge_before)) \ - .delete(synchronize_session=False) - _LOGGER.debug("Deleted %s states", deleted_rows) - - with session_scope() as session: - if self._commit(session, _purge_states): - _LOGGER.info("Purged states created before %s", purge_before) - - def _purge_events(session): - deleted_rows = session.query(Events) \ - .filter((Events.created < purge_before)) \ - .delete(synchronize_session=False) - _LOGGER.debug("Deleted %s events", deleted_rows) - - with session_scope() as session: - if self._commit(session, _purge_events): - _LOGGER.info("Purged events created before %s", purge_before) - - # Execute sqlite vacuum command to free up space on disk - if self.engine.driver == 'sqlite': - _LOGGER.info("Vacuuming SQLite to free space") - self.engine.execute("VACUUM") - - @staticmethod - def _commit(session, work): - """Commit & retry work: Either a model or in a function.""" - import sqlalchemy.exc - for _ in range(0, RETRIES): - try: - if callable(work): - work(session) - else: - session.add(work) - session.commit() - return True - except sqlalchemy.exc.OperationalError as err: - _LOGGER.error(ERROR_QUERY, err) - session.rollback() - time.sleep(QUERY_RETRY_WAIT) - return False - - -def _wait(event, message): - """Event wait helper.""" - for retry in (10, 20, 30): - event.wait(10) - if event.is_set(): - return - msg = "{} ({} seconds)".format(message, retry) - _LOGGER.warning(msg) - if not event.is_set(): - raise HomeAssistantError(msg) + with session_scope(session=self.get_session()) as session: + self.run_info.end = dt_util.utcnow() + session.add(self.run_info) + self.run_info = None diff --git a/homeassistant/components/recorder/const.py b/homeassistant/components/recorder/const.py new file mode 100644 index 00000000000..e2716ea982a --- /dev/null +++ b/homeassistant/components/recorder/const.py @@ -0,0 +1,3 @@ +"""Recorder constants.""" + +DATA_INSTANCE = 'recorder_instance' diff --git a/homeassistant/components/recorder/migration.py b/homeassistant/components/recorder/migration.py new file mode 100644 index 00000000000..09c5e9837c3 --- /dev/null +++ b/homeassistant/components/recorder/migration.py @@ -0,0 +1,88 @@ +"""Schema migration helpers.""" +import logging + +from .util import session_scope + +_LOGGER = logging.getLogger(__name__) + + +def migrate_schema(instance): + """Check if the schema needs to be upgraded.""" + from .models import SchemaChanges, SCHEMA_VERSION + + with session_scope(session=instance.get_session()) as session: + res = session.query(SchemaChanges).order_by( + SchemaChanges.change_id.desc()).first() + current_version = getattr(res, 'schema_version', None) + + if current_version == SCHEMA_VERSION: + return + + _LOGGER.debug("Database requires upgrade. Schema version: %s", + current_version) + + if current_version is None: + current_version = _inspect_schema_version(instance.engine, session) + _LOGGER.debug("No schema version found. Inspected version: %s", + current_version) + + for version in range(current_version, SCHEMA_VERSION): + new_version = version + 1 + _LOGGER.info("Upgrading recorder db schema to version %s", + new_version) + _apply_update(instance.engine, new_version) + session.add(SchemaChanges(schema_version=new_version)) + + _LOGGER.info("Upgrade to version %s done", new_version) + + +def _apply_update(engine, new_version): + """Perform operations to bring schema up to date.""" + from sqlalchemy import Table + from . import models + + if new_version == 1: + def create_index(table_name, column_name): + """Create an index for the specified table and column.""" + table = Table(table_name, models.Base.metadata) + name = "_".join(("ix", table_name, column_name)) + # Look up the index object that was created from the models + index = next(idx for idx in table.indexes if idx.name == name) + _LOGGER.debug("Creating index for table %s column %s", + table_name, column_name) + index.create(engine) + _LOGGER.debug("Index creation done for table %s column %s", + table_name, column_name) + + create_index("events", "time_fired") + else: + raise ValueError("No schema migration defined for version {}" + .format(new_version)) + + +def _inspect_schema_version(engine, session): + """Determine the schema version by inspecting the db structure. + + When the schema verison is not present in the db, either db was just + created with the correct schema, or this is a db created before schema + versions were tracked. For now, we'll test if the changes for schema + version 1 are present to make the determination. Eventually this logic + can be removed and we can assume a new db is being created. + """ + from sqlalchemy.engine import reflection + from .models import SchemaChanges, SCHEMA_VERSION + + inspector = reflection.Inspector.from_engine(engine) + indexes = inspector.get_indexes("events") + + for index in indexes: + if index['column_names'] == ["time_fired"]: + # Schema addition from version 1 detected. New DB. + session.add(SchemaChanges( + schema_version=SCHEMA_VERSION)) + return SCHEMA_VERSION + + # Version 1 schema changes not found, this db needs to be migrated. + current_version = SchemaChanges(schema_version=0) + session.add(current_version) + return current_version.schema_version diff --git a/homeassistant/components/recorder/purge.py b/homeassistant/components/recorder/purge.py new file mode 100644 index 00000000000..2b675e72759 --- /dev/null +++ b/homeassistant/components/recorder/purge.py @@ -0,0 +1,31 @@ +"""Purge old data helper.""" +from datetime import timedelta +import logging + +import homeassistant.util.dt as dt_util + +from .util import session_scope + +_LOGGER = logging.getLogger(__name__) + + +def purge_old_data(instance, purge_days): + """Purge events and states older than purge_days ago.""" + from .models import States, Events + purge_before = dt_util.utcnow() - timedelta(days=purge_days) + + with session_scope(session=instance.get_session()) as session: + deleted_rows = session.query(States) \ + .filter((States.created < purge_before)) \ + .delete(synchronize_session=False) + _LOGGER.debug("Deleted %s states", deleted_rows) + + deleted_rows = session.query(Events) \ + .filter((Events.created < purge_before)) \ + .delete(synchronize_session=False) + _LOGGER.debug("Deleted %s events", deleted_rows) + + # Execute sqlite vacuum command to free up space on disk + if instance.engine.driver == 'sqlite': + _LOGGER.info("Vacuuming SQLite to free space") + instance.engine.execute("VACUUM") diff --git a/homeassistant/components/recorder/util.py b/homeassistant/components/recorder/util.py new file mode 100644 index 00000000000..e4ea1af1060 --- /dev/null +++ b/homeassistant/components/recorder/util.py @@ -0,0 +1,71 @@ +"""SQLAlchemy util functions.""" +from contextlib import contextmanager +import logging +import time + +from .const import DATA_INSTANCE + +_LOGGER = logging.getLogger(__name__) + +RETRIES = 3 +QUERY_RETRY_WAIT = 0.1 + + +@contextmanager +def session_scope(*, hass=None, session=None): + """Provide a transactional scope around a series of operations.""" + if session is None and hass is not None: + session = hass.data[DATA_INSTANCE].get_session() + + if session is None: + raise RuntimeError('Session required') + + try: + yield session + session.commit() + except Exception as err: # pylint: disable=broad-except + _LOGGER.error('Error executing query: %s', err) + session.rollback() + raise + finally: + session.close() + + +def commit(session, work): + """Commit & retry work: Either a model or in a function.""" + import sqlalchemy.exc + for _ in range(0, RETRIES): + try: + if callable(work): + work(session) + else: + session.add(work) + session.commit() + return True + except sqlalchemy.exc.OperationalError as err: + _LOGGER.error('Error executing query: %s', err) + session.rollback() + time.sleep(QUERY_RETRY_WAIT) + return False + + +def execute(qry): + """Query the database and convert the objects to HA native form. + + This method also retries a few times in the case of stale connections. + """ + from sqlalchemy.exc import SQLAlchemyError + + for tryno in range(0, RETRIES): + try: + return [ + row for row in + (row.to_native() for row in qry) + if row is not None] + except SQLAlchemyError as err: + _LOGGER.error('Error executing query: %s', err) + + if tryno == RETRIES - 1: + raise + else: + time.sleep(QUERY_RETRY_WAIT) diff --git a/homeassistant/components/sensor/history_stats.py b/homeassistant/components/sensor/history_stats.py index b019e6745fb..eb54869d66f 100644 --- a/homeassistant/components/sensor/history_stats.py +++ b/homeassistant/components/sensor/history_stats.py @@ -164,13 +164,13 @@ class HistoryStatsSensor(Entity): # Get history between start and end history_list = history.state_changes_during_period( - start, end, str(self._entity_id)) + self.hass, start, end, str(self._entity_id)) if self._entity_id not in history_list.keys(): return # Get the first state - last_state = history.get_state(start, self._entity_id) + last_state = history.get_state(self.hass, start, self._entity_id) last_state = (last_state is not None and last_state == self._entity_state) last_time = dt_util.as_timestamp(start) diff --git a/homeassistant/helpers/restore_state.py b/homeassistant/helpers/restore_state.py index 86cd3e7037f..4ac1e442546 100644 --- a/homeassistant/helpers/restore_state.py +++ b/homeassistant/helpers/restore_state.py @@ -7,7 +7,7 @@ from homeassistant.core import HomeAssistant, CoreState, callback from homeassistant.const import EVENT_HOMEASSISTANT_START from homeassistant.components.history import get_states, last_recorder_run from homeassistant.components.recorder import ( - async_get_instance, DOMAIN as _RECORDER) + wait_connection_ready, DOMAIN as _RECORDER) import homeassistant.util.dt as dt_util _LOGGER = logging.getLogger(__name__) @@ -25,7 +25,7 @@ def _load_restore_cache(hass: HomeAssistant): hass.bus.listen_once(EVENT_HOMEASSISTANT_START, remove_cache) - last_run = last_recorder_run() + last_run = last_recorder_run(hass) if last_run is None or last_run.end is None: _LOGGER.debug('Not creating cache - no suitable last run found: %s', @@ -38,7 +38,7 @@ def _load_restore_cache(hass: HomeAssistant): last_end_time = last_end_time.replace(tzinfo=dt_util.UTC) _LOGGER.debug("Last run: %s - %s", last_run.start, last_end_time) - states = get_states(last_end_time, run=last_run) + states = get_states(hass, last_end_time, run=last_run) # Cache the states hass.data[DATA_RESTORE_CACHE] = { @@ -58,7 +58,7 @@ def async_get_last_state(hass, entity_id: str): hass.state) return None - yield from async_get_instance() # Ensure recorder ready + yield from wait_connection_ready(hass) if _LOCK not in hass.data: hass.data[_LOCK] = asyncio.Lock(loop=hass.loop) diff --git a/tests/common.py b/tests/common.py index 93ddc7c2f65..55d6896d410 100644 --- a/tests/common.py +++ b/tests/common.py @@ -28,7 +28,8 @@ from homeassistant.components import sun, mqtt, recorder from homeassistant.components.http.auth import auth_middleware from homeassistant.components.http.const import ( KEY_USE_X_FORWARDED_FOR, KEY_BANS_ENABLED, KEY_TRUSTED_NETWORKS) -from homeassistant.util.async import run_callback_threadsafe +from homeassistant.util.async import ( + run_callback_threadsafe, run_coroutine_threadsafe) _TEST_INSTANCE_PORT = SERVER_PORT _LOGGER = logging.getLogger(__name__) @@ -464,15 +465,17 @@ def assert_setup_component(count, domain=None): .format(count, res_len, res) -def init_recorder_component(hass, add_config=None, db_ready_callback=None): +def init_recorder_component(hass, add_config=None): """Initialize the recorder.""" config = dict(add_config) if add_config else {} config[recorder.CONF_DB_URL] = 'sqlite://' # In memory DB - assert setup_component(hass, recorder.DOMAIN, - {recorder.DOMAIN: config}) - assert recorder.DOMAIN in hass.config.components - recorder.get_instance().block_till_db_ready() + with patch('homeassistant.components.recorder.migration.migrate_schema'): + assert setup_component(hass, recorder.DOMAIN, + {recorder.DOMAIN: config}) + assert recorder.DOMAIN in hass.config.components + run_coroutine_threadsafe( + recorder.wait_connection_ready(hass), hass.loop).result() _LOGGER.info("In-memory recorder successfully started") diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index fa38a9d3784..0724313dcea 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -1,94 +1,29 @@ """The tests for the Recorder component.""" # pylint: disable=protected-access -import json -from datetime import datetime, timedelta import unittest -from unittest.mock import patch, call, MagicMock import pytest -from sqlalchemy import create_engine from homeassistant.core import callback from homeassistant.const import MATCH_ALL -from homeassistant.components import recorder +from homeassistant.components.recorder.const import DATA_INSTANCE +from homeassistant.components.recorder.util import session_scope +from homeassistant.components.recorder.models import States, Events from tests.common import get_test_home_assistant, init_recorder_component -from tests.components.recorder import models_original -class BaseTestRecorder(unittest.TestCase): - """Base class for common recorder tests.""" +class TestRecorder(unittest.TestCase): + """Test the recorder module.""" def setUp(self): # pylint: disable=invalid-name """Setup things to be run when tests are started.""" self.hass = get_test_home_assistant() init_recorder_component(self.hass) self.hass.start() - recorder.get_instance().block_till_done() def tearDown(self): # pylint: disable=invalid-name """Stop everything that was started.""" self.hass.stop() - with self.assertRaises(RuntimeError): - recorder.get_instance() - - def _add_test_states(self): - """Add multiple states to the db for testing.""" - now = datetime.now() - five_days_ago = now - timedelta(days=5) - attributes = {'test_attr': 5, 'test_attr_10': 'nice'} - - self.hass.block_till_done() - recorder._INSTANCE.block_till_done() - - with recorder.session_scope() as session: - for event_id in range(5): - if event_id < 3: - timestamp = five_days_ago - state = 'purgeme' - else: - timestamp = now - state = 'dontpurgeme' - - session.add(recorder.get_model('States')( - entity_id='test.recorder2', - domain='sensor', - state=state, - attributes=json.dumps(attributes), - last_changed=timestamp, - last_updated=timestamp, - created=timestamp, - event_id=event_id + 1000 - )) - - def _add_test_events(self): - """Add a few events for testing.""" - now = datetime.now() - five_days_ago = now - timedelta(days=5) - event_data = {'test_attr': 5, 'test_attr_10': 'nice'} - - self.hass.block_till_done() - recorder._INSTANCE.block_till_done() - - with recorder.session_scope() as session: - for event_id in range(5): - if event_id < 2: - timestamp = five_days_ago - event_type = 'EVENT_TEST_PURGE' - else: - timestamp = now - event_type = 'EVENT_TEST' - - session.add(recorder.get_model('Events')( - event_type=event_type, - event_data=json.dumps(event_data), - origin='LOCAL', - created=timestamp, - time_fired=timestamp, - )) - - -class TestRecorder(BaseTestRecorder): - """Test the recorder module.""" def test_saving_state(self): """Test saving and restoring a state.""" @@ -99,15 +34,14 @@ class TestRecorder(BaseTestRecorder): self.hass.states.set(entity_id, state, attributes) self.hass.block_till_done() - recorder._INSTANCE.block_till_done() + self.hass.data[DATA_INSTANCE].block_till_done() - db_states = recorder.query('States') - states = recorder.execute(db_states) + with session_scope(hass=self.hass) as session: + db_states = list(session.query(States)) + assert len(db_states) == 1 + state = db_states[0].to_native() - assert db_states[0].event_id is not None - - self.assertEqual(1, len(states)) - self.assertEqual(self.hass.states.get(entity_id), states[0]) + assert state == self.hass.states.get(entity_id) def test_saving_event(self): """Test saving and restoring an event.""" @@ -127,17 +61,17 @@ class TestRecorder(BaseTestRecorder): self.hass.bus.fire(event_type, event_data) self.hass.block_till_done() - recorder._INSTANCE.block_till_done() - - db_events = recorder.execute( - recorder.query('Events').filter_by( - event_type=event_type)) assert len(events) == 1 - assert len(db_events) == 1 - event = events[0] - db_event = db_events[0] + + self.hass.data[DATA_INSTANCE].block_till_done() + + with session_scope(hass=self.hass) as session: + db_events = list(session.query(Events).filter_by( + event_type=event_type)) + assert len(db_events) == 1 + db_event = db_events[0].to_native() assert event.event_type == db_event.event_type assert event.data == db_event.data @@ -147,110 +81,6 @@ class TestRecorder(BaseTestRecorder): assert event.time_fired.replace(microsecond=0) == \ db_event.time_fired.replace(microsecond=0) - def test_purge_old_states(self): - """Test deleting old states.""" - self._add_test_states() - # make sure we start with 5 states - states = recorder.query('States') - self.assertEqual(states.count(), 5) - - # run purge_old_data() - recorder._INSTANCE.purge_days = 4 - recorder._INSTANCE._purge_old_data() - - # we should only have 2 states left after purging - self.assertEqual(states.count(), 2) - - def test_purge_old_events(self): - """Test deleting old events.""" - self._add_test_events() - events = recorder.query('Events').filter( - recorder.get_model('Events').event_type.like("EVENT_TEST%")) - self.assertEqual(events.count(), 5) - - # run purge_old_data() - recorder._INSTANCE.purge_days = 4 - recorder._INSTANCE._purge_old_data() - - # now we should only have 3 events left - self.assertEqual(events.count(), 3) - - def test_purge_disabled(self): - """Test leaving purge_days disabled.""" - self._add_test_states() - self._add_test_events() - # make sure we start with 5 states and events - states = recorder.query('States') - events = recorder.query('Events').filter( - recorder.get_model('Events').event_type.like("EVENT_TEST%")) - self.assertEqual(states.count(), 5) - self.assertEqual(events.count(), 5) - - # run purge_old_data() - recorder._INSTANCE.purge_days = None - recorder._INSTANCE._purge_old_data() - - # we should have all of our states still - self.assertEqual(states.count(), 5) - self.assertEqual(events.count(), 5) - - def test_schema_no_recheck(self): - """Test that schema is not double-checked when up-to-date.""" - with patch.object(recorder._INSTANCE, '_apply_update') as update, \ - patch.object(recorder._INSTANCE, '_inspect_schema_version') \ - as inspect: - recorder._INSTANCE._migrate_schema() - self.assertEqual(update.call_count, 0) - self.assertEqual(inspect.call_count, 0) - - def test_invalid_update(self): - """Test that an invalid new version raises an exception.""" - with self.assertRaises(ValueError): - recorder._INSTANCE._apply_update(-1) - - -def create_engine_test(*args, **kwargs): - """Test version of create_engine that initializes with old schema. - - This simulates an existing db with the old schema. - """ - engine = create_engine(*args, **kwargs) - models_original.Base.metadata.create_all(engine) - return engine - - -class TestMigrateRecorder(BaseTestRecorder): - """Test recorder class that starts with an original schema db.""" - - @patch('sqlalchemy.create_engine', new=create_engine_test) - @patch('homeassistant.components.recorder.Recorder._migrate_schema') - def setUp(self, migrate): # pylint: disable=invalid-name,arguments-differ - """Setup things to be run when tests are started. - - create_engine is patched to create a db that starts with the old - schema. - - _migrate_schema is mocked to ensure it isn't run, so we can test it - below. - """ - super().setUp() - - def test_schema_update_calls(self): # pylint: disable=no-self-use - """Test that schema migrations occurr in correct order.""" - with patch.object(recorder._INSTANCE, '_apply_update') as update: - recorder._INSTANCE._migrate_schema() - update.assert_has_calls([call(version+1) for version in range( - 0, recorder.models.SCHEMA_VERSION)]) - - def test_schema_migrate(self): # pylint: disable=no-self-use - """Test the full schema migration logic. - - We're just testing that the logic can execute successfully here without - throwing exceptions. Maintaining a set of assertions based on schema - inspection could quickly become quite cumbersome. - """ - recorder._INSTANCE._migrate_schema() - @pytest.fixture def hass_recorder(): @@ -262,7 +92,7 @@ def hass_recorder(): init_recorder_component(hass, config) hass.start() hass.block_till_done() - recorder.get_instance().block_till_done() + hass.data[DATA_INSTANCE].block_till_done() return hass yield setup_recorder @@ -275,11 +105,10 @@ def _add_entities(hass, entity_ids): for idx, entity_id in enumerate(entity_ids): hass.states.set(entity_id, 'state{}'.format(idx), attributes) hass.block_till_done() - recorder._INSTANCE.block_till_done() - db_states = recorder.query('States') - states = recorder.execute(db_states) - assert db_states[0].event_id is not None - return states + hass.data[DATA_INSTANCE].block_till_done() + + with session_scope(hass=hass) as session: + return [st.to_native() for st in session.query(States)] # pylint: disable=redefined-outer-name,invalid-name @@ -334,61 +163,3 @@ def test_saving_state_include_domain_exclude_entity(hass_recorder): assert len(states) == 1 assert hass.states.get('test.ok') == states[0] assert hass.states.get('test.ok').state == 'state2' - - -def test_recorder_errors_exceptions(hass_recorder): \ - # pylint: disable=redefined-outer-name - """Test session_scope and get_model errors.""" - # Model cannot be resolved - assert recorder.get_model('dont-exist') is None - - # Verify the instance fails before setup - with pytest.raises(RuntimeError): - recorder.get_instance() - - # Setup the recorder - hass_recorder() - - recorder.get_instance() - - # Verify session scope raises (and prints) an exception - with patch('homeassistant.components.recorder._LOGGER.error') as e_mock, \ - pytest.raises(Exception) as err: - with recorder.session_scope() as session: - session.execute('select * from notthere') - assert e_mock.call_count == 1 - assert recorder.ERROR_QUERY[:-4] in e_mock.call_args[0][0] - assert 'no such table' in str(err.value) - - -def test_recorder_bad_commit(hass_recorder): - """Bad _commit should retry 3 times.""" - hass_recorder() - - def work(session): - """Bad work.""" - session.execute('select * from notthere') - - with patch('homeassistant.components.recorder.time.sleep') as e_mock, \ - recorder.session_scope() as session: - res = recorder._INSTANCE._commit(session, work) - assert res is False - assert e_mock.call_count == 3 - - -def test_recorder_bad_execute(hass_recorder): - """Bad execute, retry 3 times.""" - hass_recorder() - - def to_native(): - """Rasie exception.""" - from sqlalchemy.exc import SQLAlchemyError - raise SQLAlchemyError() - - mck1 = MagicMock() - mck1.to_native = to_native - - with patch('homeassistant.components.recorder.time.sleep') as e_mock: - res = recorder.execute((mck1,)) - assert res == [] - assert e_mock.call_count == 3 diff --git a/tests/components/recorder/test_migrate.py b/tests/components/recorder/test_migrate.py new file mode 100644 index 00000000000..4990cbc00eb --- /dev/null +++ b/tests/components/recorder/test_migrate.py @@ -0,0 +1,67 @@ +"""The tests for the Recorder component.""" +# pylint: disable=protected-access +import asyncio +from unittest.mock import patch, call + +import pytest +from sqlalchemy import create_engine + +from homeassistant.bootstrap import async_setup_component +from homeassistant.components.recorder import wait_connection_ready, migration +from homeassistant.components.recorder.models import SCHEMA_VERSION +from homeassistant.components.recorder.const import DATA_INSTANCE +from tests.components.recorder import models_original + + +def create_engine_test(*args, **kwargs): + """Test version of create_engine that initializes with old schema. + + This simulates an existing db with the old schema. + """ + engine = create_engine(*args, **kwargs) + models_original.Base.metadata.create_all(engine) + return engine + + +@asyncio.coroutine +def test_schema_update_calls(hass): + """Test that schema migrations occurr in correct order.""" + with patch('sqlalchemy.create_engine', new=create_engine_test), \ + patch('homeassistant.components.recorder.migration._apply_update') as \ + update: + yield from async_setup_component(hass, 'recorder', { + 'recorder': { + 'db_url': 'sqlite://' + } + }) + yield from wait_connection_ready(hass) + + update.assert_has_calls([ + call(hass.data[DATA_INSTANCE].engine, version+1) for version + in range(0, SCHEMA_VERSION)]) + + +@asyncio.coroutine +def test_schema_migrate(hass): + """Test the full schema migration logic. + + We're just testing that the logic can execute successfully here without + throwing exceptions. Maintaining a set of assertions based on schema + inspection could quickly become quite cumbersome. + """ + with patch('sqlalchemy.create_engine', new=create_engine_test), \ + patch('homeassistant.components.recorder.Recorder._setup_run') as \ + setup_run: + yield from async_setup_component(hass, 'recorder', { + 'recorder': { + 'db_url': 'sqlite://' + } + }) + yield from wait_connection_ready(hass) + assert setup_run.called + + +def test_invalid_update(): + """Test that an invalid new version raises an exception.""" + with pytest.raises(ValueError): + migration._apply_update(None, -1) diff --git a/tests/components/recorder/test_purge.py b/tests/components/recorder/test_purge.py new file mode 100644 index 00000000000..1a52e0503bb --- /dev/null +++ b/tests/components/recorder/test_purge.py @@ -0,0 +1,109 @@ +"""Test data purging.""" +import json +from datetime import datetime, timedelta +import unittest + +from homeassistant.components import recorder +from homeassistant.components.recorder.const import DATA_INSTANCE +from homeassistant.components.recorder.purge import purge_old_data +from homeassistant.components.recorder.models import States, Events +from homeassistant.components.recorder.util import session_scope +from tests.common import get_test_home_assistant, init_recorder_component + + +class TestRecorderPurge(unittest.TestCase): + """Base class for common recorder tests.""" + + def setUp(self): # pylint: disable=invalid-name + """Setup things to be run when tests are started.""" + self.hass = get_test_home_assistant() + init_recorder_component(self.hass) + self.hass.start() + + def tearDown(self): # pylint: disable=invalid-name + """Stop everything that was started.""" + self.hass.stop() + + def _add_test_states(self): + """Add multiple states to the db for testing.""" + now = datetime.now() + five_days_ago = now - timedelta(days=5) + attributes = {'test_attr': 5, 'test_attr_10': 'nice'} + + self.hass.block_till_done() + self.hass.data[DATA_INSTANCE].block_till_done() + + with recorder.session_scope(hass=self.hass) as session: + for event_id in range(5): + if event_id < 3: + timestamp = five_days_ago + state = 'purgeme' + else: + timestamp = now + state = 'dontpurgeme' + + session.add(States( + entity_id='test.recorder2', + domain='sensor', + state=state, + attributes=json.dumps(attributes), + last_changed=timestamp, + last_updated=timestamp, + created=timestamp, + event_id=event_id + 1000 + )) + + def _add_test_events(self): + """Add a few events for testing.""" + now = datetime.now() + five_days_ago = now - timedelta(days=5) + event_data = {'test_attr': 5, 'test_attr_10': 'nice'} + + self.hass.block_till_done() + self.hass.data[DATA_INSTANCE].block_till_done() + + with recorder.session_scope(hass=self.hass) as session: + for event_id in range(5): + if event_id < 2: + timestamp = five_days_ago + event_type = 'EVENT_TEST_PURGE' + else: + timestamp = now + event_type = 'EVENT_TEST' + + session.add(Events( + event_type=event_type, + event_data=json.dumps(event_data), + origin='LOCAL', + created=timestamp, + time_fired=timestamp, + )) + + def test_purge_old_states(self): + """Test deleting old states.""" + self._add_test_states() + # make sure we start with 5 states + with session_scope(hass=self.hass) as session: + states = session.query(States) + self.assertEqual(states.count(), 5) + + # run purge_old_data() + purge_old_data(self.hass.data[DATA_INSTANCE], 4) + + # we should only have 2 states left after purging + self.assertEqual(states.count(), 2) + + def test_purge_old_events(self): + """Test deleting old events.""" + self._add_test_events() + + with session_scope(hass=self.hass) as session: + events = session.query(Events).filter( + Events.event_type.like("EVENT_TEST%")) + self.assertEqual(events.count(), 5) + + # run purge_old_data() + purge_old_data(self.hass.data[DATA_INSTANCE], 4) + + # now we should only have 3 events left + self.assertEqual(events.count(), 3) diff --git a/tests/components/recorder/test_util.py b/tests/components/recorder/test_util.py new file mode 100644 index 00000000000..ad130b1ca91 --- /dev/null +++ b/tests/components/recorder/test_util.py @@ -0,0 +1,59 @@ +"""Test util methods.""" +from unittest.mock import patch, MagicMock + +import pytest + +from homeassistant.components.recorder import util +from homeassistant.components.recorder.const import DATA_INSTANCE +from tests.common import get_test_home_assistant, init_recorder_component + + +@pytest.fixture +def hass_recorder(): + """HASS fixture with in-memory recorder.""" + hass = get_test_home_assistant() + + def setup_recorder(config=None): + """Setup with params.""" + init_recorder_component(hass, config) + hass.start() + hass.block_till_done() + hass.data[DATA_INSTANCE].block_till_done() + return hass + + yield setup_recorder + hass.stop() + + +def test_recorder_bad_commit(hass_recorder): + """Bad _commit should retry 3 times.""" + hass = hass_recorder() + + def work(session): + """Bad work.""" + session.execute('select * from notthere') + + with patch('homeassistant.components.recorder.time.sleep') as e_mock, \ + util.session_scope(hass=hass) as session: + res = util.commit(session, work) + assert res is False + assert e_mock.call_count == 3 + + +def test_recorder_bad_execute(hass_recorder): + """Bad execute, retry 3 times.""" + from sqlalchemy.exc import SQLAlchemyError + hass_recorder() + + def to_native(): + """Rasie exception.""" + raise SQLAlchemyError() + + mck1 = MagicMock() + mck1.to_native = to_native + + with pytest.raises(SQLAlchemyError), \ + patch('homeassistant.components.recorder.time.sleep') as e_mock: + util.execute((mck1,)) + + assert e_mock.call_count == 2 diff --git a/tests/components/sensor/test_history_stats.py b/tests/components/sensor/test_history_stats.py index d4f1cbcbe9a..52a229f43c8 100644 --- a/tests/components/sensor/test_history_stats.py +++ b/tests/components/sensor/test_history_stats.py @@ -5,7 +5,6 @@ import unittest from unittest.mock import patch from homeassistant.bootstrap import setup_component -import homeassistant.components.recorder as recorder from homeassistant.components.sensor.history_stats import HistoryStatsSensor import homeassistant.core as ha from homeassistant.helpers.template import Template @@ -207,6 +206,3 @@ class TestHistoryStatsSensor(unittest.TestCase): """Initialize the recorder.""" init_recorder_component(self.hass) self.hass.start() - recorder.get_instance().block_till_db_ready() - self.hass.block_till_done() - recorder.get_instance().block_till_done() diff --git a/tests/components/test_history.py b/tests/components/test_history.py index 65870d1450f..7324a5e9b32 100644 --- a/tests/components/test_history.py +++ b/tests/components/test_history.py @@ -29,13 +29,12 @@ class TestComponentHistory(unittest.TestCase): """Initialize the recorder.""" init_recorder_component(self.hass) self.hass.start() - recorder.get_instance().block_till_db_ready() self.wait_recording_done() def wait_recording_done(self): """Block till recording is done.""" self.hass.block_till_done() - recorder.get_instance().block_till_done() + self.hass.data[recorder.DATA_INSTANCE].block_till_done() def test_setup(self): """Test setup method of history.""" @@ -87,12 +86,13 @@ class TestComponentHistory(unittest.TestCase): # Get states returns everything before POINT self.assertEqual(states, - sorted(history.get_states(future), + sorted(history.get_states(self.hass, future), key=lambda state: state.entity_id)) # Test get_state here because we have a DB setup self.assertEqual( - states[0], history.get_state(future, states[0].entity_id)) + states[0], history.get_state(self.hass, future, + states[0].entity_id)) def test_state_changes_during_period(self): """Test state change during period.""" @@ -128,7 +128,8 @@ class TestComponentHistory(unittest.TestCase): set_state('Netflix') set_state('Plex') - hist = history.state_changes_during_period(start, end, entity_id) + hist = history.state_changes_during_period( + self.hass, start, end, entity_id) self.assertEqual(states, hist[entity_id]) @@ -141,7 +142,7 @@ class TestComponentHistory(unittest.TestCase): """ zero, four, states = self.record_states() hist = history.get_significant_states( - zero, four, filters=history.Filters()) + self.hass, zero, four, filters=history.Filters()) assert states == hist def test_get_significant_states_entity_id(self): @@ -153,7 +154,7 @@ class TestComponentHistory(unittest.TestCase): del states['script.can_cancel_this_one'] hist = history.get_significant_states( - zero, four, 'media_player.test', + self.hass, zero, four, 'media_player.test', filters=history.Filters()) assert states == hist @@ -355,7 +356,8 @@ class TestComponentHistory(unittest.TestCase): filters.included_entities = include[history.CONF_ENTITIES] filters.included_domains = include[history.CONF_DOMAINS] - hist = history.get_significant_states(zero, four, filters=filters) + hist = history.get_significant_states( + self.hass, zero, four, filters=filters) assert states == hist def record_states(self): diff --git a/tests/components/test_logbook.py b/tests/components/test_logbook.py index 69497ef8388..13735df0a11 100644 --- a/tests/components/test_logbook.py +++ b/tests/components/test_logbook.py @@ -138,7 +138,7 @@ class TestComponentLogbook(unittest.TestCase): eventA.data['old_state'] = None events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), self.EMPTY_CONFIG) + eventA, eventB), {}) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -160,7 +160,7 @@ class TestComponentLogbook(unittest.TestCase): eventA.data['new_state'] = None events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), self.EMPTY_CONFIG) + eventA, eventB), {}) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -182,7 +182,7 @@ class TestComponentLogbook(unittest.TestCase): eventB = self.create_state_changed_event(pointB, entity_id2, 20) events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), self.EMPTY_CONFIG) + eventA, eventB), {}) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -206,8 +206,9 @@ class TestComponentLogbook(unittest.TestCase): ha.DOMAIN: {}, logbook.DOMAIN: {logbook.CONF_EXCLUDE: { logbook.CONF_ENTITIES: [entity_id, ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_STOP), eventA, eventB), + config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -231,8 +232,9 @@ class TestComponentLogbook(unittest.TestCase): ha.DOMAIN: {}, logbook.DOMAIN: {logbook.CONF_EXCLUDE: { logbook.CONF_DOMAINS: ['switch', ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_START), - eventA, eventB), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_START), eventA, eventB), + config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -267,8 +269,9 @@ class TestComponentLogbook(unittest.TestCase): ha.DOMAIN: {}, logbook.DOMAIN: {logbook.CONF_EXCLUDE: { logbook.CONF_ENTITIES: [entity_id, ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_STOP), eventA, eventB), + config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -292,8 +295,9 @@ class TestComponentLogbook(unittest.TestCase): ha.DOMAIN: {}, logbook.DOMAIN: {logbook.CONF_INCLUDE: { logbook.CONF_ENTITIES: [entity_id2, ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_STOP), - eventA, eventB), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_STOP), eventA, eventB), + config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -317,8 +321,9 @@ class TestComponentLogbook(unittest.TestCase): ha.DOMAIN: {}, logbook.DOMAIN: {logbook.CONF_INCLUDE: { logbook.CONF_DOMAINS: ['sensor', ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_START), - eventA, eventB), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_START), eventA, eventB), + config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(2, len(entries)) @@ -350,9 +355,9 @@ class TestComponentLogbook(unittest.TestCase): logbook.CONF_EXCLUDE: { logbook.CONF_DOMAINS: ['switch', ], logbook.CONF_ENTITIES: ['sensor.bli', ]}}}) - events = logbook._exclude_events((ha.Event(EVENT_HOMEASSISTANT_START), - eventA1, eventA2, eventA3, - eventB1, eventB2), config) + events = logbook._exclude_events( + (ha.Event(EVENT_HOMEASSISTANT_START), eventA1, eventA2, eventA3, + eventB1, eventB2), config[logbook.DOMAIN]) entries = list(logbook.humanify(events)) self.assertEqual(3, len(entries)) diff --git a/tests/helpers/test_restore_state.py b/tests/helpers/test_restore_state.py index 3a4c058f853..59598823911 100644 --- a/tests/helpers/test_restore_state.py +++ b/tests/helpers/test_restore_state.py @@ -10,6 +10,7 @@ import homeassistant.util.dt as dt_util from homeassistant.components import input_boolean, recorder from homeassistant.helpers.restore_state import ( async_get_last_state, DATA_RESTORE_CACHE) +from homeassistant.components.recorder.models import RecorderRuns, States from tests.common import ( get_test_home_assistant, mock_coro, init_recorder_component) @@ -31,7 +32,7 @@ def test_caching_data(hass): return_value=MagicMock(end=dt_util.utcnow())), \ patch('homeassistant.helpers.restore_state.get_states', return_value=states), \ - patch('homeassistant.helpers.restore_state.async_get_instance', + patch('homeassistant.helpers.restore_state.wait_connection_ready', return_value=mock_coro()): state = yield from async_get_last_state(hass, 'input_boolean.b1') @@ -49,33 +50,29 @@ def test_caching_data(hass): assert DATA_RESTORE_CACHE not in hass.data -def _add_data_in_last_run(entities): +def _add_data_in_last_run(hass, entities): """Add test data in the last recorder_run.""" # pylint: disable=protected-access t_now = dt_util.utcnow() - timedelta(minutes=10) t_min_1 = t_now - timedelta(minutes=20) t_min_2 = t_now - timedelta(minutes=30) - recorder_runs = recorder.get_model('RecorderRuns') - states = recorder.get_model('States') - with recorder.session_scope() as session: - run = recorder_runs( + with recorder.session_scope(hass=hass) as session: + session.add(RecorderRuns( start=t_min_2, end=t_now, created=t_min_2 - ) - recorder._INSTANCE._commit(session, run) + )) for entity_id, state in entities.items(): - dbstate = states( + session.add(States( entity_id=entity_id, domain=split_entity_id(entity_id)[0], state=state, attributes='{}', last_changed=t_min_1, last_updated=t_min_1, - created=t_min_1) - recorder._INSTANCE._commit(session, dbstate) + created=t_min_1)) def test_filling_the_cache(): @@ -88,7 +85,7 @@ def test_filling_the_cache(): init_recorder_component(hass) - _add_data_in_last_run({ + _add_data_in_last_run(hass, { test_entity_id1: 'on', test_entity_id2: 'off', })