Optimize _get_states_with_session (#56734)

* Optimize _get_states_with_session

* Move custom filters to derived table

* Remove useless derived table

* Filter old states after grouping

* Split query

* Add comments

* Simplify state update period criteria

* Only apply custom filters if we didn't get an include list of entities

Co-authored-by: J. Nick Koston <nick@koston.org>
This commit is contained in:
Erik Montnemery 2021-09-29 17:08:27 +02:00 committed by GitHub
parent daebc34f4d
commit 00651a4055
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 38 deletions

View File

@@ -80,6 +80,11 @@ def _get_significant_states(
"""
Return states changes during UTC period start_time - end_time.
entity_ids is an optional iterable of entities to include in the results.
filters is an optional SQLAlchemy filter which will be applied to the database
queries unless entity_ids is given, in which case it's ignored.
Significant states are all states where there is a state change,
as well as all states from certain domains (for instance
thermostat so that we get current temperature in our graphs).
@@ -240,47 +245,63 @@ def _get_states_with_session(
if run is None:
return []
# We have more than one entity to look at (most commonly we want
# all entities,) so we need to do a search on all states since the
# last recorder run started.
# We have more than one entity to look at so we need to do a query on states
# since the last recorder run started.
query = session.query(*QUERY_STATES)
most_recent_states_by_date = session.query(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
).filter(
(States.last_updated >= run.start) & (States.last_updated < utc_point_in_time)
)
if entity_ids:
most_recent_states_by_date.filter(States.entity_id.in_(entity_ids))
most_recent_states_by_date = most_recent_states_by_date.group_by(States.entity_id)
most_recent_states_by_date = most_recent_states_by_date.subquery()
most_recent_state_ids = session.query(
func.max(States.state_id).label("max_state_id")
).join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated == most_recent_states_by_date.c.max_last_updated,
),
# We got an include-list of entities, accelerate the query by filtering already
# in the inner query.
most_recent_state_ids = (
session.query(
func.max(States.state_id).label("max_state_id"),
)
.filter(
(States.last_updated >= run.start)
& (States.last_updated < utc_point_in_time)
)
.filter(States.entity_id.in_(entity_ids))
)
most_recent_state_ids = most_recent_state_ids.group_by(States.entity_id)
most_recent_state_ids = most_recent_state_ids.subquery()
query = query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
if entity_ids is not None:
query = query.filter(States.entity_id.in_(entity_ids))
else:
# We did not get an include-list of entities, query all states in the inner
# query, then filter out unwanted domains as well as applying the custom filter.
# This filtering can't be done in the inner query because the domain column is
# not indexed and we can't control what's in the custom filter.
most_recent_states_by_date = (
session.query(
States.entity_id.label("max_entity_id"),
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run.start)
& (States.last_updated < utc_point_in_time)
)
.group_by(States.entity_id)
.subquery()
)
most_recent_state_ids = (
session.query(func.max(States.state_id).label("max_state_id"))
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated
== most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
)
query = query.join(
most_recent_state_ids,
States.state_id == most_recent_state_ids.c.max_state_id,
)
query = query.filter(~States.domain.in_(IGNORE_DOMAINS))
if filters:
query = filters.apply(query)

View File

@@ -48,13 +48,24 @@ def test_get_states(hass_recorder):
wait_recording_done(hass)
# Get states returns everything before POINT
# Get states returns everything before POINT for all entities
for state1, state2 in zip(
states,
sorted(history.get_states(hass, future), key=lambda state: state.entity_id),
):
assert state1 == state2
# Get states returns everything before POINT for tested entities
entities = [f"test.point_in_time_{i % 5}" for i in range(5)]
for state1, state2 in zip(
states,
sorted(
history.get_states(hass, future, entities),
key=lambda state: state.entity_id,
),
):
assert state1 == state2
# Test get_state here because we have a DB setup
assert states[0] == history.get_state(hass, future, states[0].entity_id)