From 00651a40554de2003fb0824bb524c38078d012a2 Mon Sep 17 00:00:00 2001 From: Erik Montnemery Date: Wed, 29 Sep 2021 17:08:27 +0200 Subject: [PATCH] Optimize _get_states_with_session (#56734) * Optimize _get_states_with_session * Move custom filters to derived table * Remove useless derived table * Filter old states after grouping * Split query * Add comments * Simplify state update period criteria * Only apply custom filters if we didn't get an include list of entities Co-authored-by: J. Nick Koston --- homeassistant/components/recorder/history.py | 95 ++++++++++++-------- tests/components/recorder/test_history.py | 13 ++- 2 files changed, 70 insertions(+), 38 deletions(-) diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index 6c89fef2be3..36a4f6d0696 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -80,6 +80,11 @@ def _get_significant_states( """ Return states changes during UTC period start_time - end_time. + entity_ids is an optional iterable of entities to include in the results. + + filters is an optional SQLAlchemy filter which will be applied to the database + queries unless entity_ids is given, in which case it's ignored. + Significant states are all states where there is a state change, as well as all states from certain domains (for instance thermostat so that we get current temperature in our graphs). @@ -240,47 +245,63 @@ def _get_states_with_session( if run is None: return [] - # We have more than one entity to look at (most commonly we want - # all entities,) so we need to do a search on all states since the - # last recorder run started. + # We have more than one entity to look at so we need to do a query on states + # since the last recorder run started. 
query = session.query(*QUERY_STATES) - most_recent_states_by_date = session.query( - States.entity_id.label("max_entity_id"), - func.max(States.last_updated).label("max_last_updated"), - ).filter( - (States.last_updated >= run.start) & (States.last_updated < utc_point_in_time) - ) - if entity_ids: - most_recent_states_by_date.filter(States.entity_id.in_(entity_ids)) - - most_recent_states_by_date = most_recent_states_by_date.group_by(States.entity_id) - - most_recent_states_by_date = most_recent_states_by_date.subquery() - - most_recent_state_ids = session.query( - func.max(States.state_id).label("max_state_id") - ).join( - most_recent_states_by_date, - and_( - States.entity_id == most_recent_states_by_date.c.max_entity_id, - States.last_updated == most_recent_states_by_date.c.max_last_updated, - ), - ) - - most_recent_state_ids = most_recent_state_ids.group_by(States.entity_id) - - most_recent_state_ids = most_recent_state_ids.subquery() - - query = query.join( - most_recent_state_ids, - States.state_id == most_recent_state_ids.c.max_state_id, - ) - - if entity_ids is not None: - query = query.filter(States.entity_id.in_(entity_ids)) + # We got an include-list of entities, accelerate the query by filtering already + # in the inner query. + most_recent_state_ids = ( + session.query( + func.max(States.state_id).label("max_state_id"), + ) + .filter( + (States.last_updated >= run.start) + & (States.last_updated < utc_point_in_time) + ) + .filter(States.entity_id.in_(entity_ids)) + ) + most_recent_state_ids = most_recent_state_ids.group_by(States.entity_id) + most_recent_state_ids = most_recent_state_ids.subquery() + query = query.join( + most_recent_state_ids, + States.state_id == most_recent_state_ids.c.max_state_id, + ) else: + # We did not get an include-list of entities, query all states in the inner + # query, then filter out unwanted domains as well as applying the custom filter. 
+ # This filtering can't be done in the inner query because the domain column is + # not indexed and we can't control what's in the custom filter. + most_recent_states_by_date = ( + session.query( + States.entity_id.label("max_entity_id"), + func.max(States.last_updated).label("max_last_updated"), + ) + .filter( + (States.last_updated >= run.start) + & (States.last_updated < utc_point_in_time) + ) + .group_by(States.entity_id) + .subquery() + ) + most_recent_state_ids = ( + session.query(func.max(States.state_id).label("max_state_id")) + .join( + most_recent_states_by_date, + and_( + States.entity_id == most_recent_states_by_date.c.max_entity_id, + States.last_updated + == most_recent_states_by_date.c.max_last_updated, + ), + ) + .group_by(States.entity_id) + .subquery() + ) + query = query.join( + most_recent_state_ids, + States.state_id == most_recent_state_ids.c.max_state_id, + ) query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) if filters: query = filters.apply(query) diff --git a/tests/components/recorder/test_history.py b/tests/components/recorder/test_history.py index b2940f2bb39..67a666c934f 100644 --- a/tests/components/recorder/test_history.py +++ b/tests/components/recorder/test_history.py @@ -48,13 +48,24 @@ def test_get_states(hass_recorder): wait_recording_done(hass) - # Get states returns everything before POINT + # Get states returns everything before POINT for all entities for state1, state2 in zip( states, sorted(history.get_states(hass, future), key=lambda state: state.entity_id), ): assert state1 == state2 + # Get states returns everything before POINT for tested entities + entities = [f"test.point_in_time_{i % 5}" for i in range(5)] + for state1, state2 in zip( + states, + sorted( + history.get_states(hass, future, entities), + key=lambda state: state.entity_id, + ), + ): + assert state1 == state2 + # Test get_state here because we have a DB setup assert states[0] == history.get_state(hass, future, states[0].entity_id)