diff --git a/homeassistant/components/recorder/history/modern.py b/homeassistant/components/recorder/history/modern.py index 9159bbc6181..279ca9c9eea 100644 --- a/homeassistant/components/recorder/history/modern.py +++ b/homeassistant/components/recorder/history/modern.py @@ -27,8 +27,13 @@ from homeassistant.core import HomeAssistant, State, split_entity_id from homeassistant.helpers.recorder import get_instance import homeassistant.util.dt as dt_util -from ..const import LAST_REPORTED_SCHEMA_VERSION -from ..db_schema import SHARED_ATTR_OR_LEGACY_ATTRIBUTES, StateAttributes, States +from ..const import LAST_REPORTED_SCHEMA_VERSION, SupportedDialect +from ..db_schema import ( + SHARED_ATTR_OR_LEGACY_ATTRIBUTES, + StateAttributes, + States, + StatesMeta, +) from ..filters import Filters from ..models import ( LazyState, @@ -145,6 +150,7 @@ def _significant_states_stmt( no_attributes: bool, include_start_time_state: bool, run_start_ts: float | None, + lateral_join_for_start_time: bool, ) -> Select | CompoundSelect: """Query the database for significant state changes.""" include_last_changed = not significant_changes_only @@ -184,6 +190,7 @@ def _significant_states_stmt( metadata_ids, no_attributes, include_last_changed, + lateral_join_for_start_time, ).subquery(), no_attributes, include_last_changed, @@ -254,6 +261,7 @@ def get_significant_states_with_session( start_time_ts = start_time.timestamp() end_time_ts = datetime_to_timestamp_or_none(end_time) single_metadata_id = metadata_ids[0] if len(metadata_ids) == 1 else None + lateral_join_for_start_time = instance.dialect_name == SupportedDialect.POSTGRESQL stmt = lambda_stmt( lambda: _significant_states_stmt( start_time_ts, @@ -265,6 +273,7 @@ def get_significant_states_with_session( no_attributes, include_start_time_state, run_start_ts, + lateral_join_for_start_time, ), track_on=[ bool(single_metadata_id), @@ -556,30 +565,61 @@ def _get_start_time_state_for_entities_stmt( metadata_ids: list[int], no_attributes: bool, include_last_changed: bool, + lateral_join_for_start_time: bool, ) -> Select: """Baked query to get states for specific entities.""" # We got an include-list of entities, accelerate the query by filtering already # in the inner and the outer query. + if lateral_join_for_start_time: + # PostgreSQL does not support index skip scan/loose index scan + # https://wiki.postgresql.org/wiki/Loose_indexscan + # so we need to do a lateral join to get the max last_updated_ts + # for each metadata_id as a group-by is too slow. + # https://github.com/home-assistant/core/issues/132865 + max_metadata_id = StatesMeta.metadata_id.label("max_metadata_id") + max_last_updated = ( + select(func.max(States.last_updated_ts)) + .where( + (States.metadata_id == max_metadata_id) + & (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < epoch_time) + ) + .subquery() + .lateral() + ) + most_recent_states_for_entities_by_date = ( + select(max_metadata_id, max_last_updated.c[0].label("max_last_updated")) + .select_from(StatesMeta) + .join( + max_last_updated, + StatesMeta.metadata_id == max_metadata_id, + ) + .where(StatesMeta.metadata_id.in_(metadata_ids)) + ).subquery() + else: + # Simple group-by for MySQL and SQLite, must use less + # than 1000 metadata_ids in the IN clause for MySQL + # or it will optimize poorly. 
+ most_recent_states_for_entities_by_date = ( + select( + States.metadata_id.label("max_metadata_id"), + func.max(States.last_updated_ts).label("max_last_updated"), + ) + .filter( + (States.last_updated_ts >= run_start_ts) + & (States.last_updated_ts < epoch_time) + & States.metadata_id.in_(metadata_ids) + ) + .group_by(States.metadata_id) + .subquery() + ) + stmt = ( _stmt_and_join_attributes_for_start_state( no_attributes, include_last_changed, False ) .join( - ( - most_recent_states_for_entities_by_date := ( - select( - States.metadata_id.label("max_metadata_id"), - func.max(States.last_updated_ts).label("max_last_updated"), - ) - .filter( - (States.last_updated_ts >= run_start_ts) - & (States.last_updated_ts < epoch_time) - & States.metadata_id.in_(metadata_ids) - ) - .group_by(States.metadata_id) - .subquery() - ) - ), + most_recent_states_for_entities_by_date, and_( States.metadata_id == most_recent_states_for_entities_by_date.c.max_metadata_id, @@ -621,6 +661,7 @@ def _get_start_time_state_stmt( metadata_ids: list[int], no_attributes: bool, include_last_changed: bool, + lateral_join_for_start_time: bool, ) -> Select: """Return the states at a specific point in time.""" if single_metadata_id: @@ -641,6 +682,7 @@ def _get_start_time_state_stmt( metadata_ids, no_attributes, include_last_changed, + lateral_join_for_start_time, ) diff --git a/homeassistant/components/recorder/statistics.py b/homeassistant/components/recorder/statistics.py index 3f1d5b981e3..9e47ca43c5b 100644 --- a/homeassistant/components/recorder/statistics.py +++ b/homeassistant/components/recorder/statistics.py @@ -63,6 +63,7 @@ from .db_schema import ( STATISTICS_TABLES, Statistics, StatisticsBase, + StatisticsMeta, StatisticsRuns, StatisticsShortTerm, ) @@ -1669,6 +1670,7 @@ def _augment_result_with_change( drop_sum = "sum" not in _types prev_sums = {} if tmp := _statistics_at_time( + hass, session, {metadata[statistic_id][0] for statistic_id in result}, table, @@ -2032,22 +2034,50 @@ def _generate_statistics_at_time_stmt( metadata_ids: set[int], start_time_ts: float, types: set[Literal["last_reset", "max", "mean", "min", "state", "sum"]], + lateral_join_for_start_time: bool, ) -> StatementLambdaElement: """Create the statement for finding the statistics for a given time.""" stmt = _generate_select_columns_for_types_stmt(table, types) - stmt += lambda q: q.join( - ( - most_recent_statistic_ids := ( - select( - func.max(table.start_ts).label("max_start_ts"), - table.metadata_id.label("max_metadata_id"), - ) - .filter(table.start_ts < start_time_ts) - .filter(table.metadata_id.in_(metadata_ids)) - .group_by(table.metadata_id) - .subquery() + if lateral_join_for_start_time: + # PostgreSQL does not support index skip scan/loose index scan + # https://wiki.postgresql.org/wiki/Loose_indexscan + # so we need to do a lateral join to get the max max_start_ts + # for each metadata_id as a group-by is too slow. 
+ # https://github.com/home-assistant/core/issues/132865 + max_metadata_id = StatisticsMeta.id.label("max_metadata_id") + max_start = ( + select(func.max(table.start_ts)) + .filter(table.metadata_id == max_metadata_id) + .filter(table.start_ts < start_time_ts) + .filter(table.metadata_id.in_(metadata_ids)) + .subquery() + .lateral() + ) + most_recent_statistic_ids = ( + select(max_metadata_id, max_start.c[0].label("max_start_ts")) + .select_from(StatisticsMeta) + .join( + max_start, + StatisticsMeta.id == max_metadata_id, ) - ), + .where(StatisticsMeta.id.in_(metadata_ids)) + ).subquery() + else: + # Simple group-by for MySQL and SQLite, must use less + # than 1000 metadata_ids in the IN clause for MySQL + # or it will optimize poorly. + most_recent_statistic_ids = ( + select( + func.max(table.start_ts).label("max_start_ts"), + table.metadata_id.label("max_metadata_id"), + ) + .filter(table.start_ts < start_time_ts) + .filter(table.metadata_id.in_(metadata_ids)) + .group_by(table.metadata_id) + .subquery() + ) + stmt += lambda q: q.join( + most_recent_statistic_ids, and_( table.start_ts == most_recent_statistic_ids.c.max_start_ts, table.metadata_id == most_recent_statistic_ids.c.max_metadata_id, @@ -2057,6 +2087,7 @@ def _generate_statistics_at_time_stmt( def _statistics_at_time( + hass: HomeAssistant, session: Session, metadata_ids: set[int], table: type[StatisticsBase], @@ -2065,7 +2096,11 @@ def _statistics_at_time( ) -> Sequence[Row] | None: """Return last known statistics, earlier than start_time, for the metadata_ids.""" start_time_ts = start_time.timestamp() - stmt = _generate_statistics_at_time_stmt(table, metadata_ids, start_time_ts, types) + dialect_name = get_instance(hass).dialect_name + lateral_join_for_start_time = dialect_name == SupportedDialect.POSTGRESQL + stmt = _generate_statistics_at_time_stmt( + table, metadata_ids, start_time_ts, types, lateral_join_for_start_time + ) return cast(Sequence[Row], execute_stmt_lambda_element(session, stmt)) diff --git a/tests/components/recorder/test_history.py b/tests/components/recorder/test_history.py index 28b8275247c..eea4605039b 100644 --- a/tests/components/recorder/test_history.py +++ b/tests/components/recorder/test_history.py @@ -1014,3 +1014,127 @@ async def test_get_last_state_changes_with_non_existent_entity_ids_returns_empty ) -> None: """Test get_last_state_changes returns an empty dict when entities not in the db.""" assert history.get_last_state_changes(hass, 1, "nonexistent.entity") == {} + + +@pytest.mark.skip_on_db_engine(["sqlite", "mysql"]) +@pytest.mark.usefixtures("skip_by_db_engine") +@pytest.mark.usefixtures("recorder_db_url") +async def test_get_significant_states_with_session_uses_lateral_with_postgresql( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test get_significant_states_with_session uses the lateral path with PostgreSQL.""" + entity_id = "media_player.test" + hass.states.async_set("any.other", "on") + await async_wait_recording_done(hass) + hass.states.async_set(entity_id, "off") + + def set_state(state): + """Set the state.""" + hass.states.async_set(entity_id, state, {"any": 1}) + return hass.states.get(entity_id) + + start = dt_util.utcnow().replace(microsecond=0) + point = start + timedelta(seconds=1) + point2 = start + timedelta(seconds=1, microseconds=100) + point3 = start + timedelta(seconds=1, microseconds=200) + end = point + timedelta(seconds=1, microseconds=400) + + with freeze_time(start) as freezer: + set_state("idle") + set_state("YouTube") + + 
freezer.move_to(point) + states = [set_state("idle")] + + freezer.move_to(point2) + states.append(set_state("Netflix")) + + freezer.move_to(point3) + states.append(set_state("Plex")) + + freezer.move_to(end) + set_state("Netflix") + set_state("Plex") + await async_wait_recording_done(hass) + + start_time = point2 + timedelta(microseconds=10) + hist = history.get_significant_states( + hass=hass, + start_time=start_time, # Pick a point where we will generate a start time state + end_time=end, + entity_ids=[entity_id, "any.other"], + include_start_time_state=True, + ) + assert len(hist[entity_id]) == 2 + + sqlalchemy_logs = "".join( + [ + record.getMessage() + for record in caplog.records + if record.name.startswith("sqlalchemy.engine") + ] + ) + # We can't patch inside the lambda so we have to check the logs + assert "JOIN LATERAL" in sqlalchemy_logs + + +@pytest.mark.skip_on_db_engine(["postgresql"]) +@pytest.mark.usefixtures("skip_by_db_engine") +@pytest.mark.usefixtures("recorder_db_url") +async def test_get_significant_states_with_session_uses_non_lateral_without_postgresql( + hass: HomeAssistant, caplog: pytest.LogCaptureFixture +) -> None: + """Test get_significant_states_with_session does not use a the lateral path without PostgreSQL.""" + entity_id = "media_player.test" + hass.states.async_set("any.other", "on") + await async_wait_recording_done(hass) + hass.states.async_set(entity_id, "off") + + def set_state(state): + """Set the state.""" + hass.states.async_set(entity_id, state, {"any": 1}) + return hass.states.get(entity_id) + + start = dt_util.utcnow().replace(microsecond=0) + point = start + timedelta(seconds=1) + point2 = start + timedelta(seconds=1, microseconds=100) + point3 = start + timedelta(seconds=1, microseconds=200) + end = point + timedelta(seconds=1, microseconds=400) + + with freeze_time(start) as freezer: + set_state("idle") + set_state("YouTube") + + freezer.move_to(point) + states = [set_state("idle")] + + freezer.move_to(point2) + states.append(set_state("Netflix")) + + freezer.move_to(point3) + states.append(set_state("Plex")) + + freezer.move_to(end) + set_state("Netflix") + set_state("Plex") + await async_wait_recording_done(hass) + + start_time = point2 + timedelta(microseconds=10) + hist = history.get_significant_states( + hass=hass, + start_time=start_time, # Pick a point where we will generate a start time state + end_time=end, + entity_ids=[entity_id, "any.other"], + include_start_time_state=True, + ) + assert len(hist[entity_id]) == 2 + + sqlalchemy_logs = "".join( + [ + record.getMessage() + for record in caplog.records + if record.name.startswith("sqlalchemy.engine") + ] + ) + # We can't patch inside the lambda so we have to check the logs + assert "JOIN LATERAL" not in sqlalchemy_logs diff --git a/tests/components/recorder/test_statistics.py b/tests/components/recorder/test_statistics.py index 6b1e1a655db..55029c3eacf 100644 --- a/tests/components/recorder/test_statistics.py +++ b/tests/components/recorder/test_statistics.py @@ -1914,20 +1914,185 @@ def test_cache_key_for_generate_max_mean_min_statistic_in_sub_period_stmt() -> N assert cache_key_1 != cache_key_3 -def test_cache_key_for_generate_statistics_at_time_stmt() -> None: +@pytest.mark.parametrize("lateral_join_for_start_time", [True, False]) +def test_cache_key_for_generate_statistics_at_time_stmt( + lateral_join_for_start_time: bool, +) -> None: """Test cache key for _generate_statistics_at_time_stmt.""" - stmt = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set()) + stmt 
= _generate_statistics_at_time_stmt( + StatisticsShortTerm, {0}, 0.0, set(), lateral_join_for_start_time + ) cache_key_1 = stmt._generate_cache_key() - stmt2 = _generate_statistics_at_time_stmt(StatisticsShortTerm, {0}, 0.0, set()) + stmt2 = _generate_statistics_at_time_stmt( + StatisticsShortTerm, {0}, 0.0, set(), lateral_join_for_start_time + ) cache_key_2 = stmt2._generate_cache_key() assert cache_key_1 == cache_key_2 stmt3 = _generate_statistics_at_time_stmt( - StatisticsShortTerm, {0}, 0.0, {"sum", "mean"} + StatisticsShortTerm, + {0}, + 0.0, + {"sum", "mean"}, + lateral_join_for_start_time, ) cache_key_3 = stmt3._generate_cache_key() assert cache_key_1 != cache_key_3 +@pytest.mark.skip_on_db_engine(["sqlite", "mysql"]) +@pytest.mark.usefixtures("skip_by_db_engine") +@pytest.mark.usefixtures("recorder_db_url") +@pytest.mark.freeze_time("2022-10-01 00:00:00+00:00") +async def test_statistics_at_time_uses_lateral_query_with_postgresql( + hass: HomeAssistant, + setup_recorder: None, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test statistics_at_time uses a lateral query with PostgreSQL.""" + await async_wait_recording_done(hass) + assert "Compiling statistics for" not in caplog.text + assert "Statistics already compiled" not in caplog.text + + zero = dt_util.utcnow() + period1 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 01:00:00")) + period3 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 02:00:00")) + period4 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 03:00:00")) + + external_statistics = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 2, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 3, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 5, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 8, + }, + ) + external_metadata = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "recorder", + "statistic_id": "sensor.total_energy_import", + "unit_of_measurement": "kWh", + } + + async_import_statistics(hass, external_metadata, external_statistics) + await async_wait_recording_done(hass) + # Get change from far in the past + stats = statistics_during_period( + hass, + zero, + period="hour", + statistic_ids={"sensor.total_energy_import"}, + types={"change", "sum"}, + ) + assert stats + sqlalchemy_logs = "".join( + [ + record.getMessage() + for record in caplog.records + if record.name.startswith("sqlalchemy.engine") + ] + ) + # We can't patch inside the lambda so we have to check the logs + assert "JOIN LATERAL" in sqlalchemy_logs + + +@pytest.mark.skip_on_db_engine(["postgresql"]) +@pytest.mark.usefixtures("skip_by_db_engine") +@pytest.mark.usefixtures("recorder_db_url") +@pytest.mark.freeze_time("2022-10-01 00:00:00+00:00") +async def test_statistics_at_time_uses_non_lateral_query_without_postgresql( + hass: HomeAssistant, + setup_recorder: None, + caplog: pytest.LogCaptureFixture, +) -> None: + """Test statistics_at_time does not use a lateral query without PostgreSQL.""" + await async_wait_recording_done(hass) + assert "Compiling statistics for" not in caplog.text + assert "Statistics already compiled" not in caplog.text + + zero = dt_util.utcnow() + period1 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 00:00:00")) + period2 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 01:00:00")) + period3 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 
02:00:00")) + period4 = dt_util.as_utc(dt_util.parse_datetime("2023-05-08 03:00:00")) + + external_statistics = ( + { + "start": period1, + "last_reset": None, + "state": 0, + "sum": 2, + }, + { + "start": period2, + "last_reset": None, + "state": 1, + "sum": 3, + }, + { + "start": period3, + "last_reset": None, + "state": 2, + "sum": 5, + }, + { + "start": period4, + "last_reset": None, + "state": 3, + "sum": 8, + }, + ) + external_metadata = { + "has_mean": False, + "has_sum": True, + "name": "Total imported energy", + "source": "recorder", + "statistic_id": "sensor.total_energy_import", + "unit_of_measurement": "kWh", + } + + async_import_statistics(hass, external_metadata, external_statistics) + await async_wait_recording_done(hass) + # Get change from far in the past + stats = statistics_during_period( + hass, + zero, + period="hour", + statistic_ids={"sensor.total_energy_import"}, + types={"change", "sum"}, + ) + assert stats + sqlalchemy_logs = "".join( + [ + record.getMessage() + for record in caplog.records + if record.name.startswith("sqlalchemy.engine") + ] + ) + # We can't patch inside the lambda so we have to check the logs + assert "JOIN LATERAL" not in sqlalchemy_logs + + @pytest.mark.parametrize("timezone", ["America/Regina", "Europe/Vienna", "UTC"]) @pytest.mark.freeze_time("2022-10-01 00:00:00+00:00") async def test_change(