Optimize _get_states_with_session (#56734)

* Optimize _get_states_with_session

* Move custom filters to derived table

* Remove useless derived table

* Filter old states after grouping

* Split query

* Add comments

* Simplify state update period criteria

* Only apply custom filters if we didn't get an include list of entities

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2021-09-29 17:08:27 +02:00 committed by GitHub
parent daebc34f4d
commit 00651a4055
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 38 deletions

View File

@ -80,6 +80,11 @@ def _get_significant_states(
""" """
Return states changes during UTC period start_time - end_time. Return states changes during UTC period start_time - end_time.
entity_ids is an optional iterable of entities to include in the results.
filters is an optional SQLAlchemy filter which will be applied to the database
queries unless entity_ids is given, in which case it's ignored.
Significant states are all states where there is a state change, Significant states are all states where there is a state change,
as well as all states from certain domains (for instance as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs). thermostat so that we get current temperature in our graphs).
@ -240,47 +245,63 @@ def _get_states_with_session(
if run is None: if run is None:
return [] return []
# We have more than one entity to look at (most commonly we want # We have more than one entity to look at so we need to do a query on states
# all entities,) so we need to do a search on all states since the # since the last recorder run started.
# last recorder run started.
query = session.query(*QUERY_STATES) query = session.query(*QUERY_STATES)
most_recent_states_by_date = session.query(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
).filter(
(States.last_updated >= run.start) & (States.last_updated < utc_point_in_time)
)
if entity_ids: if entity_ids:
most_recent_states_by_date.filter(States.entity_id.in_(entity_ids)) # We got an include-list of entities, accelerate the query by filtering already
# in the inner query.
most_recent_states_by_date = most_recent_states_by_date.group_by(States.entity_id) most_recent_state_ids = (
session.query(
most_recent_states_by_date = most_recent_states_by_date.subquery() func.max(States.state_id).label("max_state_id"),
)
most_recent_state_ids = session.query( .filter(
func.max(States.state_id).label("max_state_id") (States.last_updated >= run.start)
).join( & (States.last_updated < utc_point_in_time)
most_recent_states_by_date, )
and_( .filter(States.entity_id.in_(entity_ids))
States.entity_id == most_recent_states_by_date.c.max_entity_id, )
States.last_updated == most_recent_states_by_date.c.max_last_updated, most_recent_state_ids = most_recent_state_ids.group_by(States.entity_id)
), most_recent_state_ids = most_recent_state_ids.subquery()
) query = query.join(
most_recent_state_ids,
most_recent_state_ids = most_recent_state_ids.group_by(States.entity_id) States.state_id == most_recent_state_ids.c.max_state_id,
)
most_recent_state_ids = most_recent_state_ids.subquery()
query = query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
if entity_ids is not None:
query = query.filter(States.entity_id.in_(entity_ids))
else: else:
# We did not get an include-list of entities, query all states in the inner
# query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = (
session.query(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run.start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
)
most_recent_state_ids = (
session.query(func.max(States.state_id).label("max_state_id"))
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated
== most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
)
query = query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
query = query.filter(~States.domain.in_(IGNORE_DOMAINS)) query = query.filter(~States.domain.in_(IGNORE_DOMAINS))
if filters: if filters:
query = filters.apply(query) query = filters.apply(query)

View File

@ -48,13 +48,24 @@ def test_get_states(hass_recorder):
wait_recording_done(hass) wait_recording_done(hass)
# Get states returns everything before POINT # Get states returns everything before POINT for all entities
for state1, state2 in zip( for state1, state2 in zip(
states, states,
sorted(history.get_states(hass, future), key=lambda state: state.entity_id), sorted(history.get_states(hass, future), key=lambda state: state.entity_id),
): ):
assert state1 == state2 assert state1 == state2
# Get states returns everything before POINT for tested entities
entities = [f"test.point_in_time_{i % 5}" for i in range(5)]
for state1, state2 in zip(
states,
sorted(
history.get_states(hass, future, entities),
key=lambda state: state.entity_id,
),
):
assert state1 == state2
# Test get_state here because we have a DB setup # Test get_state here because we have a DB setup
assert states[0] == history.get_state(hass, future, states[0].entity_id) assert states[0] == history.get_state(hass, future, states[0].entity_id)