Fix most recent states query not using the index for group by (#88461)

* Fix most recent states query not using the index for group by

fixes #87851

* Apply suggestions from code review

* reduce
This commit is contained in:
J. Nick Koston 2023-02-19 20:05:45 -06:00 committed by GitHub
parent 9a6bcc2b63
commit eac9ad8437
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -583,7 +583,49 @@ def get_last_state_changes(
) )
def _get_states_for_entites_stmt( def _generate_most_recent_states_for_entities_by_date(
schema_version: int,
run_start: datetime,
utc_point_in_time: datetime,
entity_ids: list[str],
) -> Subquery:
"""Generate the sub query for the most recent states for specific entities by date."""
# Both branches build the same shape of subquery: one row per entity_id,
# labelled (max_entity_id, max_last_updated), restricted to `entity_ids`
# and to the half-open window run_start <= last_updated < utc_point_in_time.
# _get_states_for_entities_stmt joins States against these two labels.
if schema_version >= 31:
# From schema 31 on, the comparison is made against States.last_updated_ts,
# so both datetime bounds are converted to timestamps before filtering.
run_start_ts = process_timestamp(run_start).timestamp()
utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
return (
select(
States.entity_id.label("max_entity_id"),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(States.last_updated_ts).label("max_last_updated"),
)
.filter(
(States.last_updated_ts >= run_start_ts)
& (States.last_updated_ts < utc_point_in_time_ts)
)
.filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
)
# Schemas older than 31: compare the datetime arguments directly against
# States.last_updated; the "max_last_updated" label is kept identical so
# callers can reference the same subquery column name in either case.
return (
select(
States.entity_id.label("max_entity_id"),
# https://github.com/sqlalchemy/sqlalchemy/issues/9189
# pylint: disable-next=not-callable
func.max(States.last_updated).label("max_last_updated"),
)
.filter(
(States.last_updated >= run_start)
& (States.last_updated < utc_point_in_time)
)
.filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
)
def _get_states_for_entities_stmt(
schema_version: int, schema_version: int,
run_start: datetime, run_start: datetime,
utc_point_in_time: datetime, utc_point_in_time: datetime,
@ -594,41 +636,32 @@ def _get_states_for_entites_stmt(
stmt, join_attributes = lambda_stmt_and_join_attributes( stmt, join_attributes = lambda_stmt_and_join_attributes(
schema_version, no_attributes, include_last_changed=True schema_version, no_attributes, include_last_changed=True
) )
most_recent_states_for_entities_by_date = (
_generate_most_recent_states_for_entities_by_date(
schema_version, run_start, utc_point_in_time, entity_ids
)
)
# We got an include-list of entities, accelerate the query by filtering already # We got an include-list of entities, accelerate the query by filtering already
# in the inner query. # in the inner query.
if schema_version >= 31: if schema_version >= 31:
run_start_ts = process_timestamp(run_start).timestamp() stmt += lambda q: q.join(
utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) most_recent_states_for_entities_by_date,
stmt += lambda q: q.where( and_(
States.state_id States.entity_id
== ( == most_recent_states_for_entities_by_date.c.max_entity_id,
# https://github.com/sqlalchemy/sqlalchemy/issues/9189 States.last_updated_ts
# pylint: disable-next=not-callable == most_recent_states_for_entities_by_date.c.max_last_updated,
select(func.max(States.state_id).label("max_state_id")) ),
.filter(
(States.last_updated_ts >= run_start_ts)
& (States.last_updated_ts < utc_point_in_time_ts)
)
.filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
).c.max_state_id
) )
else: else:
stmt += lambda q: q.where( stmt += lambda q: q.join(
States.state_id most_recent_states_for_entities_by_date,
== ( and_(
# https://github.com/sqlalchemy/sqlalchemy/issues/9189 States.entity_id
# pylint: disable-next=not-callable == most_recent_states_for_entities_by_date.c.max_entity_id,
select(func.max(States.state_id).label("max_state_id")) States.last_updated
.filter( == most_recent_states_for_entities_by_date.c.max_last_updated,
(States.last_updated >= run_start) ),
& (States.last_updated < utc_point_in_time)
)
.filter(States.entity_id.in_(entity_ids))
.group_by(States.entity_id)
.subquery()
).c.max_state_id
) )
if join_attributes: if join_attributes:
stmt += lambda q: q.outerjoin( stmt += lambda q: q.outerjoin(
@ -642,7 +675,7 @@ def _generate_most_recent_states_by_date(
run_start: datetime, run_start: datetime,
utc_point_in_time: datetime, utc_point_in_time: datetime,
) -> Subquery: ) -> Subquery:
"""Generate the sub query for the most recent states by data.""" """Generate the sub query for the most recent states by date."""
if schema_version >= 31: if schema_version >= 31:
run_start_ts = process_timestamp(run_start).timestamp() run_start_ts = process_timestamp(run_start).timestamp()
utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time) utc_point_in_time_ts = dt_util.utc_to_timestamp(utc_point_in_time)
@ -695,42 +728,20 @@ def _get_states_for_all_stmt(
schema_version, run_start, utc_point_in_time schema_version, run_start, utc_point_in_time
) )
if schema_version >= 31: if schema_version >= 31:
stmt += lambda q: q.where( stmt += lambda q: q.join(
States.state_id most_recent_states_by_date,
== ( and_(
# https://github.com/sqlalchemy/sqlalchemy/issues/9189 States.entity_id == most_recent_states_by_date.c.max_entity_id,
# pylint: disable-next=not-callable States.last_updated_ts == most_recent_states_by_date.c.max_last_updated,
select(func.max(States.state_id).label("max_state_id")) ),
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated_ts
== most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
).c.max_state_id,
) )
else: else:
stmt += lambda q: q.where( stmt += lambda q: q.join(
States.state_id most_recent_states_by_date,
== ( and_(
# https://github.com/sqlalchemy/sqlalchemy/issues/9189 States.entity_id == most_recent_states_by_date.c.max_entity_id,
# pylint: disable-next=not-callable States.last_updated == most_recent_states_by_date.c.max_last_updated,
select(func.max(States.state_id).label("max_state_id")) ),
.join(
most_recent_states_by_date,
and_(
States.entity_id == most_recent_states_by_date.c.max_entity_id,
States.last_updated
== most_recent_states_by_date.c.max_last_updated,
),
)
.group_by(States.entity_id)
.subquery()
).c.max_state_id,
) )
stmt += _ignore_domains_filter stmt += _ignore_domains_filter
if filters and filters.has_config: if filters and filters.has_config:
@ -772,7 +783,7 @@ def _get_rows_with_session(
# We have more than one entity to look at so we need to do a query on states # We have more than one entity to look at so we need to do a query on states
# since the last recorder run started. # since the last recorder run started.
if entity_ids: if entity_ids:
stmt = _get_states_for_entites_stmt( stmt = _get_states_for_entities_stmt(
schema_version, run.start, utc_point_in_time, entity_ids, no_attributes schema_version, run.start, utc_point_in_time, entity_ids, no_attributes
) )
else: else: