diff --git a/homeassistant/components/logbook/__init__.py b/homeassistant/components/logbook/__init__.py
index 53063d36fc0..a877cba4ff2 100644
--- a/homeassistant/components/logbook/__init__.py
+++ b/homeassistant/components/logbook/__init__.py
@@ -11,11 +11,13 @@ from typing import Any, cast
 
 from aiohttp import web
 import sqlalchemy
+from sqlalchemy import lambda_stmt, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
+from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 import voluptuous as vol
 
 from homeassistant.components import frontend
@@ -85,8 +87,6 @@ CONTINUOUS_ENTITY_ID_LIKE = [f"{domain}.%" for domain in CONTINUOUS_DOMAINS]
 
 DOMAIN = "logbook"
 
-GROUP_BY_MINUTES = 15
-
 EMPTY_JSON_OBJECT = "{}"
 UNIT_OF_MEASUREMENT_JSON = '"unit_of_measurement":'
 UNIT_OF_MEASUREMENT_JSON_LIKE = f"%{UNIT_OF_MEASUREMENT_JSON}%"
@@ -435,70 +435,43 @@ def _get_events(
 
     def yield_rows(query: Query) -> Generator[Row, None, None]:
         """Yield Events that are not filtered away."""
-        for row in query.yield_per(1000):
+        if entity_ids or context_id:
+            rows = query.all()
+        else:
+            rows = query.yield_per(1000)
+        for row in rows:
             context_lookup.setdefault(row.context_id, row)
-            if row.event_type != EVENT_CALL_SERVICE and (
-                row.event_type == EVENT_STATE_CHANGED
-                or _keep_row(hass, row, entities_filter)
+            event_type = row.event_type
+            if event_type != EVENT_CALL_SERVICE and (
+                event_type == EVENT_STATE_CHANGED
+                or _keep_row(hass, event_type, row, entities_filter)
             ):
                 yield row
 
     if entity_ids is not None:
         entities_filter = generate_filter([], entity_ids, [], [])
 
+    event_types = [
+        *ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED,
+        *hass.data.get(DOMAIN, {}),
+    ]
+    entity_filter = None
+    if entity_ids is None and filters:
+        entity_filter = filters.entity_filter()  # type: ignore[no-untyped-call]
+    stmt = _generate_logbook_query(
+        start_day,
+        end_day,
+        event_types,
+        entity_ids,
+        entity_filter,
+        entity_matches_only,
+        context_id,
+    )
     with session_scope(hass=hass) as session:
-        old_state = aliased(States, name="old_state")
-        query: Query
-        query = _generate_events_query_without_states(session)
-        query = _apply_event_time_filter(query, start_day, end_day)
-        query = _apply_event_types_filter(
-            hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
-        )
-
-        if entity_ids is not None:
-            if entity_matches_only:
-                # When entity_matches_only is provided, contexts and events that do not
-                # contain the entity_ids are not included in the logbook response.
-                query = _apply_event_entity_id_matchers(query, entity_ids)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-            query = query.union_all(
-                _generate_states_query(
-                    session, start_day, end_day, old_state, entity_ids
-                )
-            )
-        else:
-            if context_id is not None:
-                query = query.filter(Events.context_id == context_id)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-
-            states_query = _generate_states_query(
-                session, start_day, end_day, old_state, entity_ids
-            )
-            unions: list[Query] = []
-            if context_id is not None:
-                # Once all the old `state_changed` events
-                # are gone from the database remove the
-                # _generate_legacy_events_context_id_query
-                unions.append(
-                    _generate_legacy_events_context_id_query(
-                        session, context_id, start_day, end_day
-                    )
-                )
-                states_query = states_query.outerjoin(
-                    Events, (States.event_id == Events.event_id)
-                )
-                states_query = states_query.filter(States.context_id == context_id)
-            elif filters:
-                states_query = states_query.filter(filters.entity_filter())  # type: ignore[no-untyped-call]
-            unions.append(states_query)
-            query = query.union_all(*unions)
-
-        query = query.order_by(Events.time_fired)
-
         return list(
             _humanify(
                 hass,
-                yield_rows(query),
+                yield_rows(session.execute(stmt)),
                 entity_name_cache,
                 event_cache,
                 context_augmenter,
@@ -506,8 +479,72 @@ def _get_events(
     )
 
 
-def _generate_events_query_without_data(session: Session) -> Query:
-    return session.query(
+def _generate_logbook_query(
+    start_day: dt,
+    end_day: dt,
+    event_types: list[str],
+    entity_ids: list[str] | None = None,
+    entity_filter: Any | None = None,
+    entity_matches_only: bool = False,
+    context_id: str | None = None,
+) -> StatementLambdaElement:
+    """Generate a logbook query lambda_stmt."""
+    stmt = lambda_stmt(
+        lambda: _generate_events_query_without_states()
+        .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+        .where(Events.event_type.in_(event_types))
+        .outerjoin(EventData, (Events.data_id == EventData.data_id))
+    )
+    if entity_ids is not None:
+        if entity_matches_only:
+            # When entity_matches_only is provided, contexts and events that do not
+            # contain the entity_ids are not included in the logbook response.
+            stmt.add_criteria(
+                lambda s: s.where(_apply_event_entity_id_matchers(entity_ids)),
+                track_on=entity_ids,
+            )
+        stmt += lambda s: s.union_all(
+            _generate_states_query()
+            .filter((States.last_updated > start_day) & (States.last_updated < end_day))
+            .where(States.entity_id.in_(entity_ids))
+        )
+    else:
+        if context_id is not None:
+            # Once all the old `state_changed` events
+            # are gone from the database remove the
+            # union_all(_generate_legacy_events_context_id_query()....)
+            stmt += lambda s: s.where(Events.context_id == context_id).union_all(
+                _generate_legacy_events_context_id_query()
+                .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+                .where(Events.context_id == context_id),
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .outerjoin(Events, (States.event_id == Events.event_id))
+                .where(States.context_id == context_id),
+            )
+        elif entity_filter is not None:
+            stmt += lambda s: s.union_all(
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .where(entity_filter)
+            )
+        else:
+            stmt += lambda s: s.union_all(
+                _generate_states_query().where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+            )
+
+    stmt += lambda s: s.order_by(Events.time_fired)
+    return stmt
+
+
+def _generate_events_query_without_data() -> Select:
+    return select(
         literal(value=EVENT_STATE_CHANGED, type_=sqlalchemy.String).label("event_type"),
         literal(value=None, type_=sqlalchemy.Text).label("event_data"),
         States.last_changed.label("time_fired"),
@@ -519,65 +556,48 @@ def _generate_events_query_without_data(session: Session) -> Query:
     )
 
 
-def _generate_legacy_events_context_id_query(
-    session: Session,
-    context_id: str,
-    start_day: dt,
-    end_day: dt,
-) -> Query:
+def _generate_legacy_events_context_id_query() -> Select:
     """Generate a legacy events context id query that also joins states."""
     # This can be removed once we no longer have event_ids in the states table
-    legacy_context_id_query = session.query(
-        *EVENT_COLUMNS,
-        literal(value=None, type_=sqlalchemy.String).label("shared_data"),
-        States.state,
-        States.entity_id,
-        States.attributes,
-        StateAttributes.shared_attrs,
-    )
-    legacy_context_id_query = _apply_event_time_filter(
-        legacy_context_id_query, start_day, end_day
-    )
     return (
-        legacy_context_id_query.filter(Events.context_id == context_id)
+        select(
+            *EVENT_COLUMNS,
+            literal(value=None, type_=sqlalchemy.String).label("shared_data"),
+            States.state,
+            States.entity_id,
+            States.attributes,
+            StateAttributes.shared_attrs,
+        )
         .outerjoin(States, (Events.event_id == States.event_id))
-        .filter(States.last_updated == States.last_changed)
-        .filter(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .where(_not_continuous_entity_matcher())
         .outerjoin(
            StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
        )
    )
 
 
-def _generate_events_query_without_states(session: Session) -> Query:
-    return session.query(
+def _generate_events_query_without_states() -> Select:
+    return select(
         *EVENT_COLUMNS, EventData.shared_data.label("shared_data"), *EMPTY_STATE_COLUMNS
     )
 
 
-def _generate_states_query(
-    session: Session,
-    start_day: dt,
-    end_day: dt,
-    old_state: States,
-    entity_ids: Iterable[str] | None,
-) -> Query:
-    query = (
-        _generate_events_query_without_data(session)
+def _generate_states_query() -> Select:
+    old_state = aliased(States, name="old_state")
+    return (
+        _generate_events_query_without_data()
         .outerjoin(old_state, (States.old_state_id == old_state.state_id))
-        .filter(_missing_state_matcher(old_state))
-        .filter(_not_continuous_entity_matcher())
-        .filter((States.last_updated > start_day) & (States.last_updated < end_day))
-        .filter(States.last_updated == States.last_changed)
-    )
-    if entity_ids:
-        query = query.filter(States.entity_id.in_(entity_ids))
-    return query.outerjoin(
-        StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        .where(_missing_state_matcher(old_state))
+        .where(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .outerjoin(
+            StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        )
     )
 
 
-def _missing_state_matcher(old_state: States) -> Any:
+def _missing_state_matcher(old_state: States) -> sqlalchemy.and_:
     # The below removes state change events that do not have
     # and old_state or the old_state is missing (newly added entities)
     # or the new_state is missing (removed entities)
@@ -588,7 +608,7 @@ def _missing_state_matcher(old_state: States) -> Any:
     )
 
 
-def _not_continuous_entity_matcher() -> Any:
+def _not_continuous_entity_matcher() -> sqlalchemy.or_:
     """Match non continuous entities."""
     return sqlalchemy.or_(
         _not_continuous_domain_matcher(),
@@ -598,7 +618,7 @@ def _not_continuous_entity_matcher() -> Any:
     )
 
 
-def _not_continuous_domain_matcher() -> Any:
+def _not_continuous_domain_matcher() -> sqlalchemy.and_:
     """Match not continuous domains."""
     return sqlalchemy.and_(
         *[
@@ -608,7 +628,7 @@ def _not_continuous_domain_matcher() -> Any:
     ).self_group()
 
 
-def _continuous_domain_matcher() -> Any:
+def _continuous_domain_matcher() -> sqlalchemy.or_:
     """Match continuous domains."""
     return sqlalchemy.or_(
         *[
@@ -625,37 +645,22 @@ def _not_uom_attributes_matcher() -> Any:
     ) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
 
 
-def _apply_event_time_filter(events_query: Query, start_day: dt, end_day: dt) -> Query:
-    return events_query.filter(
-        (Events.time_fired > start_day) & (Events.time_fired < end_day)
-    )
-
-
-def _apply_event_types_filter(
-    hass: HomeAssistant, query: Query, event_types: list[str]
-) -> Query:
-    return query.filter(
-        Events.event_type.in_(event_types + list(hass.data.get(DOMAIN, {})))
-    )
-
-
-def _apply_event_entity_id_matchers(
-    events_query: Query, entity_ids: Iterable[str]
-) -> Query:
+def _apply_event_entity_id_matchers(entity_ids: Iterable[str]) -> sqlalchemy.or_:
+    """Create matchers for the entity_id in the event_data."""
     ors = []
     for entity_id in entity_ids:
         like = ENTITY_ID_JSON_TEMPLATE.format(entity_id)
         ors.append(Events.event_data.like(like))
         ors.append(EventData.shared_data.like(like))
-    return events_query.filter(sqlalchemy.or_(*ors))
+    return sqlalchemy.or_(*ors)
 
 
 def _keep_row(
     hass: HomeAssistant,
+    event_type: str,
     row: Row,
     entities_filter: EntityFilter | Callable[[str], bool] | None = None,
 ) -> bool:
-    event_type = row.event_type
     if event_type in HOMEASSISTANT_EVENTS:
         return entities_filter is None or entities_filter(HA_DOMAIN_ENTITY_ID)
 
diff --git a/tests/components/logbook/test_init.py b/tests/components/logbook/test_init.py
index 4f961dcd2c1..76319ba5a6e 100644
--- a/tests/components/logbook/test_init.py
+++ b/tests/components/logbook/test_init.py
@@ -1390,6 +1390,59 @@ async def test_logbook_entity_matches_only(hass, hass_client, recorder_mock):
     assert json_dict[1]["context_user_id"] == "9400facee45711eaa9308bfd3d19e474"
 
 
+async def test_logbook_entity_matches_only_multiple_calls(
+    hass, hass_client, recorder_mock
+):
+    """Test the logbook view with a single entity and entity_matches_only called multiple times."""
+    await async_setup_component(hass, "logbook", {})
+    await async_setup_component(hass, "automation", {})
+
+    await async_recorder_block_till_done(hass)
+    await hass.async_block_till_done()
+    await hass.async_start()
+    await hass.async_block_till_done()
+
+    for automation_id in range(5):
+        hass.bus.async_fire(
+            EVENT_AUTOMATION_TRIGGERED,
+            {
ATTR_NAME: f"Mock automation {automation_id}", + ATTR_ENTITY_ID: f"automation.mock_{automation_id}_automation", + }, + ) + await async_wait_recording_done(hass) + client = await hass_client() + + # Today time 00:00:00 + start = dt_util.utcnow().date() + start_date = datetime(start.year, start.month, start.day) + end_time = start + timedelta(hours=24) + + for automation_id in range(5): + # Test today entries with filter by end_time + response = await client.get( + f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_{automation_id}_automation&entity_matches_only" + ) + assert response.status == HTTPStatus.OK + json_dict = await response.json() + + assert len(json_dict) == 1 + assert ( + json_dict[0]["entity_id"] == f"automation.mock_{automation_id}_automation" + ) + + response = await client.get( + f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_0_automation,automation.mock_1_automation,automation.mock_2_automation&entity_matches_only" + ) + assert response.status == HTTPStatus.OK + json_dict = await response.json() + + assert len(json_dict) == 3 + assert json_dict[0]["entity_id"] == "automation.mock_0_automation" + assert json_dict[1]["entity_id"] == "automation.mock_1_automation" + assert json_dict[2]["entity_id"] == "automation.mock_2_automation" + + async def test_custom_log_entry_discoverable_via_entity_matches_only( hass, hass_client, recorder_mock ):