Optimize start time state queries for PostgreSQL (#133228)

J. Nick Koston 2024-12-18 19:41:53 -10:00 committed by GitHub
parent 3fe08a7223
commit 99698ef95d
4 changed files with 400 additions and 34 deletions
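In short: the recorder's "state at start time" lookups previously found, for every requested metadata_id, its newest row before the query window with a GROUP BY/max() subquery. PostgreSQL has no index skip scan/loose index scan, so that aggregate degrades badly on large states tables (issue #132865 referenced in the diff below). This commit keeps the group-by for MySQL and SQLite and, on PostgreSQL only, drives the lookup from the small metadata table with a correlated LATERAL subquery. The following is a minimal sketch of the two shapes on a toy schema; the table and column definitions are illustrative, not the recorder's real models.

# Minimal sketch (toy schema, not the recorder's real models) of the two query
# shapes, compiled for PostgreSQL to show the "JOIN LATERAL" it emits.
from sqlalchemy import Column, Float, Integer, MetaData, Table, func, select
from sqlalchemy.dialects import postgresql

meta = MetaData()
states = Table(
    "states",
    meta,
    Column("state_id", Integer, primary_key=True),
    Column("metadata_id", Integer),
    Column("last_updated_ts", Float),
)
states_meta = Table(
    "states_meta", meta, Column("metadata_id", Integer, primary_key=True)
)

metadata_ids = [1, 2, 3]
epoch_time = 1_700_000_000.0

# Portable shape (kept for MySQL/SQLite): one aggregate over the large states table.
group_by_form = (
    select(
        states.c.metadata_id.label("max_metadata_id"),
        func.max(states.c.last_updated_ts).label("max_last_updated"),
    )
    .where(
        (states.c.last_updated_ts < epoch_time)
        & states.c.metadata_id.in_(metadata_ids)
    )
    .group_by(states.c.metadata_id)
)

# PostgreSQL shape: iterate the small metadata table and fetch each id's newest
# timestamp with a correlated LATERAL subquery (one tight index probe per id).
max_metadata_id = states_meta.c.metadata_id.label("max_metadata_id")
max_last_updated = (
    select(func.max(states.c.last_updated_ts))
    .where(
        (states.c.metadata_id == max_metadata_id)
        & (states.c.last_updated_ts < epoch_time)
    )
    .subquery()
    .lateral()
)
lateral_form = (
    select(max_metadata_id, max_last_updated.c[0].label("max_last_updated"))
    .select_from(states_meta)
    .join(max_last_updated, states_meta.c.metadata_id == max_metadata_id)
    .where(states_meta.c.metadata_id.in_(metadata_ids))
)

print(lateral_form.compile(dialect=postgresql.dialect()))  # contains "JOIN LATERAL"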

View File

@ -27,8 +27,13 @@ from homeassistant.core import HomeAssistant, State, split_entity_id
from homeassistant.helpers.recorder import get_instance
import homeassistant.util.dt as dt_util
from ..const import LAST_REPORTED_SCHEMA_VERSION
from ..db_schema import SHARED_ATTR_OR_LEGACY_ATTRIBUTES, StateAttributes, States
from ..const import LAST_REPORTED_SCHEMA_VERSION, SupportedDialect
from ..db_schema import (
SHARED_ATTR_OR_LEGACY_ATTRIBUTES,
StateAttributes,
States,
StatesMeta,
)
from ..filters import Filters
from ..models import (
LazyState,
@ -145,6 +150,7 @@ def _significant_states_stmt(
no_attributes: bool,
include_start_time_state: bool,
run_start_ts: float | None,
lateral_join_for_start_time: bool,
) -> Select | CompoundSelect:
"""Query the database for significant state changes."""
include_last_changed = not significant_changes_only
@ -184,6 +190,7 @@ def _significant_states_stmt(
metadata_ids,
no_attributes,
include_last_changed,
lateral_join_for_start_time,
).subquery(),
no_attributes,
include_last_changed,
@ -254,6 +261,7 @@ def get_significant_states_with_session(
start_time_ts = start_time.timestamp()
end_time_ts = datetime_to_timestamp_or_none(end_time)
single_metadata_id = metadata_ids[0] if len(metadata_ids) == 1 else None
lateral_join_for_start_time = instance.dialect_name == SupportedDialect.POSTGRESQL
stmt = lambda_stmt(
lambda: _significant_states_stmt(
start_time_ts,
@ -265,6 +273,7 @@ def get_significant_states_with_session(
no_attributes,
include_start_time_state,
run_start_ts,
lateral_join_for_start_time,
),
track_on=[
bool(single_metadata_id),
@ -556,30 +565,61 @@ def _get_start_time_state_for_entities_stmt(
metadata_ids: list[int],
no_attributes: bool,
include_last_changed: bool,
lateral_join_for_start_time: bool,
) -> Select:
"""Baked query to get states for specific entities."""
# We got an include-list of entities, accelerate the query by filtering already
# in the inner and the outer query.
if lateral_join_for_start_time:
# PostgreSQL does not support index skip scan/loose index scan
# https://wiki.postgresql.org/wiki/Loose_indexscan
# so we need to do a lateral join to get the max last_updated_ts
# for each metadata_id as a group-by is too slow.
# https://github.com/home-assistant/core/issues/132865
max_metadata_id = StatesMeta.metadata_id.label("max_metadata_id")
max_last_updated = (
select(func.max(States.last_updated_ts))
.where(
(States.metadata_id == max_metadata_id)
& (States.last_updated_ts >= run_start_ts)
& (States.last_updated_ts < epoch_time)
)
.subquery()
.lateral()
)
most_recent_states_for_entities_by_date = (
select(max_metadata_id, max_last_updated.c[0].label("max_last_updated"))
.select_from(StatesMeta)
.join(
max_last_updated,
StatesMeta.metadata_id == max_metadata_id,
)
.where(StatesMeta.metadata_id.in_(metadata_ids))
).subquery()
else:
# Simple group-by for MySQL and SQLite, must use less
# than 1000 metadata_ids in the IN clause for MySQL
# or it will optimize poorly.
most_recent_states_for_entities_by_date = (
select(
States.metadata_id.label("max_metadata_id"),
func.max(States.last_updated_ts).label("max_last_updated"),
)
.filter(
(States.last_updated_ts >= run_start_ts)
& (States.last_updated_ts < epoch_time)
& States.metadata_id.in_(metadata_ids)
)
.group_by(States.metadata_id)
.subquery()
)
stmt = (
_stmt_and_join_attributes_for_start_state(
no_attributes, include_last_changed, False
)
.join(
(
most_recent_states_for_entities_by_date := (
select(
States.metadata_id.label("max_metadata_id"),
func.max(States.last_updated_ts).label("max_last_updated"),
)
.filter(
(States.last_updated_ts >= run_start_ts)
& (States.last_updated_ts < epoch_time)
& States.metadata_id.in_(metadata_ids)
)
.group_by(States.metadata_id)
.subquery()
)
),
most_recent_states_for_entities_by_date,
and_(
States.metadata_id
== most_recent_states_for_entities_by_date.c.max_metadata_id,
@ -621,6 +661,7 @@ def _get_start_time_state_stmt(
metadata_ids: list[int],
no_attributes: bool,
include_last_changed: bool,
lateral_join_for_start_time: bool,
) -> Select:
"""Return the states at a specific point in time."""
if single_metadata_id:
@ -641,6 +682,7 @@ def _get_start_time_state_stmt(
metadata_ids,
no_attributes,
include_last_changed,
lateral_join_for_start_time,
)
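
Note on the wiring in this file: the new lateral_join_for_start_time flag is computed once per call in get_significant_states_with_session from the recorder instance's dialect and threaded through _significant_states_stmt and _get_start_time_state_stmt into _get_start_time_state_for_entities_stmt, which picks the lateral or group-by branch. A hedged sketch of the same gate with plain SQLAlchemy (the recorder itself compares instance.dialect_name against SupportedDialect.POSTGRESQL rather than inspecting an engine directly):

# Hedged sketch of the dialect gate with plain SQLAlchemy; the recorder checks
# instance.dialect_name == SupportedDialect.POSTGRESQL instead.
from sqlalchemy import create_engine

engine = create_engine("sqlite://")  # a postgresql:// URL in a real deployment
lateral_join_for_start_time = engine.dialect.name == "postgresql"
print(lateral_join_for_start_time)  # False here, so the GROUP BY branch would be built

Because the statement is assembled inside lambda_stmt, the flag travels in as an ordinary argument rather than being patched in afterwards, which is also why the new tests below verify the behavior through SQL logs instead of mocking.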

View File

@ -63,6 +63,7 @@ from .db_schema import (
STATISTICS_TABLES,
Statistics,
StatisticsBase,
StatisticsMeta,
StatisticsRuns,
StatisticsShortTerm,
)
@ -1669,6 +1670,7 @@ def _augment_result_with_change(
drop_sum = "sum" not in _types
prev_sums = {}
if tmp := _statistics_at_time(
hass,
session,
{metadata[statistic_id][0] for statistic_id in result},
table,
@ -2032,22 +2034,50 @@ def _generate_statistics_at_time_stmt(
metadata_ids: set[int],
start_time_ts: float,
types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]],
lateral_join_for_start_time: bool,
) -> StatementLambdaElement:
"""Create the statement for finding the statistics for a given time."""
stmt = _generate_select_columns_for_types_stmt(table, types)
stmt += lambda q: q.join(
(
most_recent_statistic_ids := (
select(
func.max(table.start_ts).label("max_start_ts"),
table.metadata_id.label("max_metadata_id"),
)
.filter(table.start_ts < start_time_ts)
.filter(table.metadata_id.in_(metadata_ids))
.group_by(table.metadata_id)
.subquery()
if lateral_join_for_start_time:
# PostgreSQL does not support index skip scan/loose index scan
# https://wiki.postgresql.org/wiki/Loose_indexscan
# so we need to do a lateral join to get the max start_ts
# for each metadata_id as a group-by is too slow.
# https://github.com/home-assistant/core/issues/132865
max_metadata_id = StatisticsMeta.id.label("max_metadata_id")
max_start = (
select(func.max(table.start_ts))
.filter(table.metadata_id == max_metadata_id)
.filter(table.start_ts < start_time_ts)
.filter(table.metadata_id.in_(metadata_ids))
.subquery()
.lateral()
)
most_recent_statistic_ids = (
select(max_metadata_id, max_start.c[0].label("max_start_ts"))
.select_from(StatisticsMeta)
.join(
max_start,
StatisticsMeta.id == max_metadata_id,
)
),
.where(StatisticsMeta.id.in_(metadata_ids))
).subquery()
else:
# Simple group-by for MySQL and SQLite, must use less
# than 1000 metadata_ids in the IN clause for MySQL
# or it will optimize poorly.
most_recent_statistic_ids = (
select(
func.max(table.start_ts).label("max_start_ts"),
table.metadata_id.label("max_metadata_id"),
)
.filter(table.start_ts < start_time_ts)
.filter(table.metadata_id.in_(metadata_ids))
.group_by(table.metadata_id)
.subquery()
)
stmt += lambda q: q.join(
most_recent_statistic_ids,
and_(
table.start_ts == most_recent_statistic_ids.c.max_start_ts,
table.metadata_id == most_recent_statistic_ids.c.max_metadata_id,
@ -2057,6 +2087,7 @@ def _generate_statistics_at_time_stmt(
def _statistics_at_time(
hass: HomeAssistant,
session: Session,
metadata_ids: set[int],
table: type[StatisticsBase],
@ -2065,7 +2096,11 @@ def _statistics_at_time(
) -> Sequence[Row] | None:
"""Return last known statistics, earlier than start_time, for the metadata_ids."""
start_time_ts = start_time.timestamp()
stmt = _generate_statistics_at_time_stmt(table, metadata_ids, start_time_ts, types)
dialect_name = get_instance(hass).dialect_name
lateral_join_for_start_time = dialect_name == SupportedDialect.POSTGRESQL
stmt = _generate_statistics_at_time_stmt(
table, metadata_ids, start_time_ts, types, lateral_join_for_start_time
)
return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt))
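
The statistics path mirrors the history change: _statistics_at_time now resolves the dialect via get_instance(hass) and passes the flag into _generate_statistics_at_time_stmt, where either branch yields one (max_start_ts, max_metadata_id) row per id. Those rows are then joined back to the statistics table on both columns to pull each id's full last-known row, as in the and_() join shown above. A hedged sketch of that join-back step on a toy table (names illustrative):

# Hedged sketch (toy table, illustrative names) of the join-back step: one row of
# (max_start_ts, max_metadata_id) per id -- from either branch above -- is joined to
# the statistics table on both columns to return each id's last row before the cutoff.
from sqlalchemy import Column, Float, Integer, MetaData, Table, and_, func, select

meta = MetaData()
stats = Table(
    "statistics_short_term",
    meta,
    Column("id", Integer, primary_key=True),
    Column("metadata_id", Integer),
    Column("start_ts", Float),
    Column("sum", Float),
)

metadata_ids = {1, 2}
start_time_ts = 1_700_000_000.0

most_recent = (
    select(
        func.max(stats.c.start_ts).label("max_start_ts"),
        stats.c.metadata_id.label("max_metadata_id"),
    )
    .where(stats.c.start_ts < start_time_ts)
    .where(stats.c.metadata_id.in_(metadata_ids))
    .group_by(stats.c.metadata_id)
    .subquery()
)

last_known = select(stats).join(
    most_recent,
    and_(
        stats.c.start_ts == most_recent.c.max_start_ts,
        stats.c.metadata_id == most_recent.c.max_metadata_id,
    ),
)
print(last_known)  # one full row per metadata_id: its last statistics before the cutoff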

View File

@ -1014,3 +1014,127 @@ async def test_get_last_state_changes_with_non_existent_entity_ids_returns_empty
) -> None:
"""Test get_last_state_changes returns an empty dict when entities not in the db."""
assert history.get_last_state_changes(hass, 1, "nonexistent.entity") == {}
@pytest.mark.skip_on_db_engine(["sqlite", "mysql"])
@pytest.mark.usefixtures("skip_by_db_engine")
@pytest.mark.usefixtures("recorder_db_url")
async def test_get_significant_states_with_session_uses_lateral_with_postgresql(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test get_significant_states_with_session uses the lateral path with PostgreSQL."""
entity_id = "media_player.test"
hass.states.async_set("any.other", "on")
await async_wait_recording_done(hass)
hass.states.async_set(entity_id, "off")
def set_state(state):
"""Set the state."""
hass.states.async_set(entity_id, state, {"any": 1})
return hass.states.get(entity_id)
start = dt_util.utcnow().replace(microsecond=0)
point = start + timedelta(seconds=1)
point2 = start + timedelta(seconds=1, microseconds=100)
point3 = start + timedelta(seconds=1, microseconds=200)
end = point + timedelta(seconds=1, microseconds=400)
with freeze_time(start) as freezer:
set_state("idle")
set_state("YouTube")
freezer.move_to(point)
states = [set_state("idle")]
freezer.move_to(point2)
states.append(set_state("Netflix"))
freezer.move_to(point3)
states.append(set_state("Plex"))
freezer.move_to(end)
set_state("Netflix")
set_state("Plex")
await async_wait_recording_done(hass)
start_time = point2 + timedelta(microseconds=10)
hist = history.get_significant_states(
hass=hass,
start_time=start_time, # Pick a point where we will generate a start time state
end_time=end,
entity_ids=[entity_id, "any.other"],
include_start_time_state=True,
)
assert len(hist[entity_id]) == 2
sqlalchemy_logs = "".join(
[
record.getMessage()
for record in caplog.records
if record.name.startswith("sqlalchemy.engine")
]
)
# We can't patch inside the lambda so we have to check the logs
assert "JOIN LATERAL" in sqlalchemy_logs
@pytest.mark.skip_on_db_engine(["postgresql"])
@pytest.mark.usefixtures("skip_by_db_engine")
@pytest.mark.usefixtures("recorder_db_url")
async def test_get_significant_states_with_session_uses_non_lateral_without_postgresql(
hass: HomeAssistant, caplog: pytest.LogCaptureFixture
) -> None:
"""Test get_significant_states_with_session does not use a the lateral path without PostgreSQL."""
entity_id = "media_player.test"
hass.states.async_set("any.other", "on")
await async_wait_recording_done(hass)
hass.states.async_set(entity_id, "off")
def set_state(state):
"""Set the state."""
hass.states.async_set(entity_id, state, {"any": 1})
return hass.states.get(entity_id)
start = dt_util.utcnow().replace(microsecond=0)
point = start + timedelta(seconds=1)
point2 = start + timedelta(seconds=1, microseconds=100)
point3 = start + timedelta(seconds=1, microseconds=200)
end = point + timedelta(seconds=1, microseconds=400)
with freeze_time(start) as freezer:
set_state("idle")
set_state("YouTube")
freezer.move_to(point)
states = [set_state("idle")]
freezer.move_to(point2)
states.append(set_state("Netflix"))
freezer.move_to(point3)
states.append(set_state("Plex"))
freezer.move_to(end)
set_state("Netflix")
set_state("Plex")
await async_wait_recording_done(hass)
start_time = point2 + timedelta(microseconds=10)
hist = history.get_significant_states(
hass=hass,
start_time=start_time, # Pick a point where we will generate a start time state
end_time=end,
entity_ids=[entity_id, "any.other"],
include_start_time_state=True,
)
assert len(hist[entity_id]) == 2
sqlalchemy_logs = "".join(
[
record.getMessage()
for record in caplog.records
if record.name.startswith("sqlalchemy.engine")
]
)
# We can't patch inside the lambda so we have to check the logs
assert "JOIN LATERAL" not in sqlalchemy_logs

View File

@ -1914,20 +1914,185 @@ def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt() -> None
assert cache_key_1 != cache_key_3
def test_cache_key_for_generate_statistics_at_time_stmt() -> None:
@pytest.mark.parametrize("lateral_join_for_start_time", [True, False])
def test_cache_key_for_generate_statistics_at_time_stmt(
lateral_join_for_start_time: bool,
) -> None:
"""Test cache key for _generate_statistics_at_time_stmt."""
stmt = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set())
stmt = _generate_statistics_at_time_stmt(
StatisticsShortTerm, {0}, 0.0, set(), lateral_join_for_start_time
)
cache_key_1 = stmt._generate_cache_key()
stmt2 = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set())
stmt2 = _generate_statistics_at_time_stmt(
StatisticsShortTerm, {0}, 0.0, set(), lateral_join_for_start_time
)
cache_key_2 = stmt2._generate_cache_key()
assert cache_key_1 == cache_key_2
stmt3 = _generate_statistics_at_time_stmt(
StatisticsShortTerm, {0}, 0.0, {"sum", "mean"}
StatisticsShortTerm,
{0},
0.0,
{"sum", "mean"},
lateral_join_for_start_time,
)
cache_key_3 = stmt3._generate_cache_key()
assert cache_key_1 != cache_key_3
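
The cache-key test above relies on a general SQLAlchemy property: structurally identical statements produce equal cache keys and structurally different ones do not, so the lateral and group-by variants built from the new flag can never collide in the compiled-statement cache. A hedged toy illustration with plain selects:

# Hedged toy illustration (not part of the committed change) of the property the
# cache-key test depends on: same statement shape -> same cache key, different
# shape -> different key.
from sqlalchemy import Column, Integer, MetaData, Table, select

meta = MetaData()
t = Table("t", meta, Column("id", Integer), Column("v", Integer))

key_a = select(t.c.id)._generate_cache_key()
key_b = select(t.c.id)._generate_cache_key()
key_c = select(t.c.id, t.c.v)._generate_cache_key()

assert key_a == key_b  # same shape -> same key
assert key_a != key_c  # different shape -> different key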
@pytest.mark.skip_on_db_engine(["sqlite", "mysql"])
@pytest.mark.usefixtures("skip_by_db_engine")
@pytest.mark.usefixtures("recorder_db_url")
@pytest.mark.freeze_time("2022-10-01 00:00:00+00:00")
async def test_statistics_at_time_uses_lateral_query_with_postgresql(
hass: HomeAssistant,
setup_recorder: None,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test statistics_at_time uses a lateral query with PostgreSQL."""
await async_wait_recording_done(hass)
assert "Compiling statistics for" not in caplog.text
assert "Statistics already compiled" not in caplog.text
zero = dt_util.utcnow()
period1 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 00:00:00"))
period2 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 01:00:00"))
period3 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 02:00:00"))
period4 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 03:00:00"))
external_statistics = (
{
"start": period1,
"last_reset": None,
"state": 0,
"sum": 2,
},
{
"start": period2,
"last_reset": None,
"state": 1,
"sum": 3,
},
{
"start": period3,
"last_reset": None,
"state": 2,
"sum": 5,
},
{
"start": period4,
"last_reset": None,
"state": 3,
"sum": 8,
},
)
external_metadata = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "recorder",
"statistic_id": "sensor.total_energy_import",
"unit_of_measurement": "kWh",
}
async_import_statistics(hass, external_metadata, external_statistics)
await async_wait_recording_done(hass)
# Get change from far in the past
stats = statistics_during_period(
hass,
zero,
period="hour",
statistic_ids={"sensor.total_energy_import"},
types={"change", "sum"},
)
assert stats
sqlalchemy_logs = "".join(
[
record.getMessage()
for record in caplog.records
if record.name.startswith("sqlalchemy.engine")
]
)
# We can't patch inside the lambda so we have to check the logs
assert "JOIN LATERAL" in sqlalchemy_logs
@pytest.mark.skip_on_db_engine(["postgresql"])
@pytest.mark.usefixtures("skip_by_db_engine")
@pytest.mark.usefixtures("recorder_db_url")
@pytest.mark.freeze_time("2022-10-01 00:00:00+00:00")
async def test_statistics_at_time_uses_non_lateral_query_without_postgresql(
hass: HomeAssistant,
setup_recorder: None,
caplog: pytest.LogCaptureFixture,
) -> None:
"""Test statistics_at_time does not use a lateral query without PostgreSQL."""
await async_wait_recording_done(hass)
assert "Compiling statistics for" not in caplog.text
assert "Statistics already compiled" not in caplog.text
zero = dt_util.utcnow()
period1 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 00:00:00"))
period2 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 01:00:00"))
period3 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 02:00:00"))
period4 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 03:00:00"))
external_statistics = (
{
"start": period1,
"last_reset": None,
"state": 0,
"sum": 2,
},
{
"start": period2,
"last_reset": None,
"state": 1,
"sum": 3,
},
{
"start": period3,
"last_reset": None,
"state": 2,
"sum": 5,
},
{
"start": period4,
"last_reset": None,
"state": 3,
"sum": 8,
},
)
external_metadata = {
"has_mean": False,
"has_sum": True,
"name": "Total imported energy",
"source": "recorder",
"statistic_id": "sensor.total_energy_import",
"unit_of_measurement": "kWh",
}
async_import_statistics(hass, external_metadata, external_statistics)
await async_wait_recording_done(hass)
# Get change from far in the past
stats = statistics_during_period(
hass,
zero,
period="hour",
statistic_ids={"sensor.total_energy_import"},
types={"change", "sum"},
)
assert stats
sqlalchemy_logs = "".join(
[
record.getMessage()
for record in caplog.records
if record.name.startswith("sqlalchemy.engine")
]
)
# We can't patch inside the lambda so we have to check the logs
assert "JOIN LATERAL" not in sqlalchemy_logs
@pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"])
@pytest.mark.freeze_time("2022-10-01 00:00:00+00:00")
async def test_change(