Convert logbook to use lambda_stmt (#71624)

J. Nick Koston 2022-05-10 08:23:13 -05:00 committed by GitHub
parent 68c2b63ca1
commit 26177bd080
2 changed files with 179 additions and 121 deletions
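
For context: SQLAlchemy 1.4's lambda_stmt caches the analysis and SQL compilation of a statement, so repeated calls only re-extract bound parameters from the lambda closures instead of rebuilding the query. A minimal, self-contained sketch of the pattern this commit adopts (the schema and names below are illustrative, not the recorder's):

# Illustrative schema, not code from this commit: each lambda body is
# analyzed and compiled once; closure variables such as start_day are
# extracted as bound parameters on every subsequent call.
from datetime import datetime, timedelta

from sqlalchemy import (
    Column,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    lambda_stmt,
    select,
)

metadata = MetaData()
events = Table(
    "events",
    metadata,
    Column("event_id", Integer, primary_key=True),
    Column("event_type", String(64)),
    Column("time_fired", DateTime),
)


def period_stmt(start_day: datetime, end_day: datetime):
    stmt = lambda_stmt(
        lambda: select(events)
        .where(events.c.time_fired > start_day)
        .where(events.c.time_fired < end_day)
    )
    # Further criteria chain on with `+=`; each link becomes part of the
    # cached statement.
    stmt += lambda s: s.order_by(events.c.time_fired)
    return stmt


engine = create_engine("sqlite://")
metadata.create_all(engine)
with engine.connect() as conn:
    # The second call reuses the compiled SQL; only the bound dates differ.
    start = datetime(2022, 5, 10)
    conn.execute(period_stmt(start, start + timedelta(days=1))).all()
    conn.execute(period_stmt(start, start + timedelta(days=2))).all()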

homeassistant/components/logbook/__init__.py

@@ -11,11 +11,13 @@ from typing import Any, cast
 from aiohttp import web
 import sqlalchemy
+from sqlalchemy import lambda_stmt, select
 from sqlalchemy.engine.row import Row
 from sqlalchemy.orm import aliased
 from sqlalchemy.orm.query import Query
-from sqlalchemy.orm.session import Session
 from sqlalchemy.sql.expression import literal
+from sqlalchemy.sql.lambdas import StatementLambdaElement
+from sqlalchemy.sql.selectable import Select
 import voluptuous as vol
 
 from homeassistant.components import frontend
@@ -85,8 +87,6 @@ CONTINUOUS_ENTITY_ID_LIKE = [f"{domain}.%" for domain in CONTINUOUS_DOMAINS]
 DOMAIN = "logbook"
 
-GROUP_BY_MINUTES = 15
-
 EMPTY_JSON_OBJECT = "{}"
 UNIT_OF_MEASUREMENT_JSON = '"unit_of_measurement":'
 UNIT_OF_MEASUREMENT_JSON_LIKE = f"%{UNIT_OF_MEASUREMENT_JSON}%"
@@ -435,70 +435,43 @@ def _get_events(
     def yield_rows(query: Query) -> Generator[Row, None, None]:
         """Yield Events that are not filtered away."""
-        for row in query.yield_per(1000):
+        if entity_ids or context_id:
+            rows = query.all()
+        else:
+            rows = query.yield_per(1000)
+        for row in rows:
             context_lookup.setdefault(row.context_id, row)
-            if row.event_type != EVENT_CALL_SERVICE and (
-                row.event_type == EVENT_STATE_CHANGED
-                or _keep_row(hass, row, entities_filter)
+            event_type = row.event_type
+            if event_type != EVENT_CALL_SERVICE and (
+                event_type == EVENT_STATE_CHANGED
+                or _keep_row(hass, event_type, row, entities_filter)
             ):
                 yield row
 
     if entity_ids is not None:
         entities_filter = generate_filter([], entity_ids, [], [])
 
+    event_types = [
+        *ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED,
+        *hass.data.get(DOMAIN, {}),
+    ]
+
+    entity_filter = None
+    if entity_ids is None and filters:
+        entity_filter = filters.entity_filter()  # type: ignore[no-untyped-call]
+
+    stmt = _generate_logbook_query(
+        start_day,
+        end_day,
+        event_types,
+        entity_ids,
+        entity_filter,
+        entity_matches_only,
+        context_id,
+    )
+
     with session_scope(hass=hass) as session:
-        old_state = aliased(States, name="old_state")
-        query: Query
-        query = _generate_events_query_without_states(session)
-        query = _apply_event_time_filter(query, start_day, end_day)
-        query = _apply_event_types_filter(
-            hass, query, ALL_EVENT_TYPES_EXCEPT_STATE_CHANGED
-        )
-
-        if entity_ids is not None:
-            if entity_matches_only:
-                # When entity_matches_only is provided, contexts and events that do not
-                # contain the entity_ids are not included in the logbook response.
-                query = _apply_event_entity_id_matchers(query, entity_ids)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-            query = query.union_all(
-                _generate_states_query(
-                    session, start_day, end_day, old_state, entity_ids
-                )
-            )
-        else:
-            if context_id is not None:
-                query = query.filter(Events.context_id == context_id)
-            query = query.outerjoin(EventData, (Events.data_id == EventData.data_id))
-            states_query = _generate_states_query(
-                session, start_day, end_day, old_state, entity_ids
-            )
-            unions: list[Query] = []
-            if context_id is not None:
-                # Once all the old `state_changed` events
-                # are gone from the database remove the
-                # _generate_legacy_events_context_id_query
-                unions.append(
-                    _generate_legacy_events_context_id_query(
-                        session, context_id, start_day, end_day
-                    )
-                )
-                states_query = states_query.outerjoin(
-                    Events, (States.event_id == Events.event_id)
-                )
-                states_query = states_query.filter(States.context_id == context_id)
-            elif filters:
-                states_query = states_query.filter(filters.entity_filter())  # type: ignore[no-untyped-call]
-            unions.append(states_query)
-            query = query.union_all(*unions)
-
-        query = query.order_by(Events.time_fired)
-
         return list(
             _humanify(
                 hass,
-                yield_rows(query),
+                yield_rows(session.execute(stmt)),
                 entity_name_cache,
                 event_cache,
                 context_augmenter,
@@ -506,8 +479,72 @@ def _get_events(
         )
 
 
-def _generate_events_query_without_data(session: Session) -> Query:
-    return session.query(
+def _generate_logbook_query(
+    start_day: dt,
+    end_day: dt,
+    event_types: list[str],
+    entity_ids: list[str] | None = None,
+    entity_filter: Any | None = None,
+    entity_matches_only: bool = False,
+    context_id: str | None = None,
+) -> StatementLambdaElement:
+    """Generate a logbook query lambda_stmt."""
+    stmt = lambda_stmt(
+        lambda: _generate_events_query_without_states()
+        .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+        .where(Events.event_type.in_(event_types))
+        .outerjoin(EventData, (Events.data_id == EventData.data_id))
+    )
+    if entity_ids is not None:
+        if entity_matches_only:
+            # When entity_matches_only is provided, contexts and events that do not
+            # contain the entity_ids are not included in the logbook response.
+            stmt.add_criteria(
+                lambda s: s.where(_apply_event_entity_id_matchers(entity_ids)),
+                track_on=entity_ids,
+            )
+        stmt += lambda s: s.union_all(
+            _generate_states_query()
+            .filter((States.last_updated > start_day) & (States.last_updated < end_day))
+            .where(States.entity_id.in_(entity_ids))
+        )
+    else:
+        if context_id is not None:
+            # Once all the old `state_changed` events
+            # are gone from the database remove the
+            # union_all(_generate_legacy_events_context_id_query()....)
+            stmt += lambda s: s.where(Events.context_id == context_id).union_all(
+                _generate_legacy_events_context_id_query()
+                .where((Events.time_fired > start_day) & (Events.time_fired < end_day))
+                .where(Events.context_id == context_id),
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .outerjoin(Events, (States.event_id == Events.event_id))
+                .where(States.context_id == context_id),
+            )
+        elif entity_filter is not None:
+            stmt += lambda s: s.union_all(
+                _generate_states_query()
+                .where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+                .where(entity_filter)
+            )
+        else:
+            stmt += lambda s: s.union_all(
+                _generate_states_query().where(
+                    (States.last_updated > start_day) & (States.last_updated < end_day)
+                )
+            )
+    stmt += lambda s: s.order_by(Events.time_fired)
+    return stmt
+
+
+def _generate_events_query_without_data() -> Select:
+    return select(
         literal(value=EVENT_STATE_CHANGED, type_=sqlalchemy.String).label("event_type"),
         literal(value=None, type_=sqlalchemy.Text).label("event_data"),
         States.last_changed.label("time_fired"),
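
The subtle piece of _generate_logbook_query above is add_criteria(..., track_on=entity_ids). A lambda's default cache key is its position in the source code, so SQL built inside the lambda body (here, the OR of LIKE matchers) would otherwise be frozen on first use and reused for every later call. track_on folds the entity ids into the cache key, giving each distinct set its own cached criteria. A self-contained sketch of the mechanism under an assumed toy model (note that SQLAlchemy's documented form assigns the return value of add_criteria):

# Illustrative model, not the recorder schema; demonstrates track_on keying
# the statement cache on the entity_ids themselves.
from sqlalchemy import Column, Integer, String, create_engine, lambda_stmt, or_, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Event(Base):
    __tablename__ = "events"
    id = Column(Integer, primary_key=True)
    event_data = Column(String)


def matching_events(entity_ids):
    stmt = lambda_stmt(lambda: select(Event))
    # The OR-of-LIKEs is constructed inside the lambda body; track_on makes
    # entity_ids part of the cache key, so different ids build and cache
    # their own criteria instead of silently reusing the first call's SQL.
    stmt = stmt.add_criteria(
        lambda s: s.where(
            or_(*[Event.event_data.like(f'%"{e}"%') for e in entity_ids])
        ),
        track_on=entity_ids,
    )
    return stmt


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as session:
    session.add(Event(event_data='{"entity_id": "light.kitchen"}'))
    session.commit()
    assert len(session.execute(matching_events(["light.kitchen"])).all()) == 1
    assert len(session.execute(matching_events(["light.porch"])).all()) == 0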
@@ -519,65 +556,48 @@ def _generate_events_query_without_data(session: Session) -> Query:
     )
 
 
-def _generate_legacy_events_context_id_query(
-    session: Session,
-    context_id: str,
-    start_day: dt,
-    end_day: dt,
-) -> Query:
+def _generate_legacy_events_context_id_query() -> Select:
     """Generate a legacy events context id query that also joins states."""
     # This can be removed once we no longer have event_ids in the states table
-    legacy_context_id_query = session.query(
-        *EVENT_COLUMNS,
-        literal(value=None, type_=sqlalchemy.String).label("shared_data"),
-        States.state,
-        States.entity_id,
-        States.attributes,
-        StateAttributes.shared_attrs,
-    )
-    legacy_context_id_query = _apply_event_time_filter(
-        legacy_context_id_query, start_day, end_day
-    )
     return (
-        legacy_context_id_query.filter(Events.context_id == context_id)
+        select(
+            *EVENT_COLUMNS,
+            literal(value=None, type_=sqlalchemy.String).label("shared_data"),
+            States.state,
+            States.entity_id,
+            States.attributes,
+            StateAttributes.shared_attrs,
+        )
         .outerjoin(States, (Events.event_id == States.event_id))
-        .filter(States.last_updated == States.last_changed)
-        .filter(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .where(_not_continuous_entity_matcher())
         .outerjoin(
             StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
         )
     )
 
 
-def _generate_events_query_without_states(session: Session) -> Query:
-    return session.query(
+def _generate_events_query_without_states() -> Select:
+    return select(
         *EVENT_COLUMNS, EventData.shared_data.label("shared_data"), *EMPTY_STATE_COLUMNS
     )
 
 
-def _generate_states_query(
-    session: Session,
-    start_day: dt,
-    end_day: dt,
-    old_state: States,
-    entity_ids: Iterable[str] | None,
-) -> Query:
-    query = (
-        _generate_events_query_without_data(session)
+def _generate_states_query() -> Select:
+    old_state = aliased(States, name="old_state")
+    return (
+        _generate_events_query_without_data()
         .outerjoin(old_state, (States.old_state_id == old_state.state_id))
-        .filter(_missing_state_matcher(old_state))
-        .filter(_not_continuous_entity_matcher())
-        .filter((States.last_updated > start_day) & (States.last_updated < end_day))
-        .filter(States.last_updated == States.last_changed)
-    )
-    if entity_ids:
-        query = query.filter(States.entity_id.in_(entity_ids))
-    return query.outerjoin(
-        StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        .where(_missing_state_matcher(old_state))
+        .where(_not_continuous_entity_matcher())
+        .where(States.last_updated == States.last_changed)
+        .outerjoin(
+            StateAttributes, (States.attributes_id == StateAttributes.attributes_id)
+        )
     )
 
 
-def _missing_state_matcher(old_state: States) -> Any:
+def _missing_state_matcher(old_state: States) -> sqlalchemy.and_:
     # The below removes state change events that do not have
     # and old_state or the old_state is missing (newly added entities)
     # or the new_state is missing (removed entities)
@@ -588,7 +608,7 @@ def _missing_state_matcher(old_state: States) -> Any:
     )
 
 
-def _not_continuous_entity_matcher() -> Any:
+def _not_continuous_entity_matcher() -> sqlalchemy.or_:
     """Match non continuous entities."""
     return sqlalchemy.or_(
         _not_continuous_domain_matcher(),
@@ -598,7 +618,7 @@ def _not_continuous_entity_matcher() -> Any:
     )
 
 
-def _not_continuous_domain_matcher() -> Any:
+def _not_continuous_domain_matcher() -> sqlalchemy.and_:
     """Match not continuous domains."""
     return sqlalchemy.and_(
         *[
@@ -608,7 +628,7 @@ def _not_continuous_domain_matcher() -> Any:
     ).self_group()
 
 
-def _continuous_domain_matcher() -> Any:
+def _continuous_domain_matcher() -> sqlalchemy.or_:
     """Match continuous domains."""
     return sqlalchemy.or_(
         *[
@@ -625,37 +645,22 @@ def _not_uom_attributes_matcher() -> Any:
     ) | ~States.attributes.like(UNIT_OF_MEASUREMENT_JSON_LIKE)
 
 
-def _apply_event_time_filter(events_query: Query, start_day: dt, end_day: dt) -> Query:
-    return events_query.filter(
-        (Events.time_fired > start_day) & (Events.time_fired < end_day)
-    )
-
-
-def _apply_event_types_filter(
-    hass: HomeAssistant, query: Query, event_types: list[str]
-) -> Query:
-    return query.filter(
-        Events.event_type.in_(event_types + list(hass.data.get(DOMAIN, {})))
-    )
-
-
-def _apply_event_entity_id_matchers(
-    events_query: Query, entity_ids: Iterable[str]
-) -> Query:
+def _apply_event_entity_id_matchers(entity_ids: Iterable[str]) -> sqlalchemy.or_:
+    """Create matchers for the entity_id in the event_data."""
     ors = []
     for entity_id in entity_ids:
         like = ENTITY_ID_JSON_TEMPLATE.format(entity_id)
         ors.append(Events.event_data.like(like))
         ors.append(EventData.shared_data.like(like))
-    return events_query.filter(sqlalchemy.or_(*ors))
+    return sqlalchemy.or_(*ors)
 
 
 def _keep_row(
     hass: HomeAssistant,
+    event_type: str,
     row: Row,
     entities_filter: EntityFilter | Callable[[str], bool] | None = None,
 ) -> bool:
-    event_type = row.event_type
     if event_type in HOMEASSISTANT_EVENTS:
         return entities_filter is None or entities_filter(HA_DOMAIN_ENTITY_ID)
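
The theme running through the file-level changes above: helpers that previously required a live Session (session.query(...)) now return session-free Select objects, which is what allows them to be assembled inside the cached lambdas and executed later with session.execute(). A before/after sketch under an assumed toy schema:

from sqlalchemy import Column, Integer, MetaData, String, Table, select
from sqlalchemy.orm import Session
from sqlalchemy.sql.selectable import Select

metadata = MetaData()
states = Table(
    "states",
    metadata,
    Column("state_id", Integer, primary_key=True),
    Column("entity_id", String),
)


def states_query_legacy(session: Session):
    # Before: construction is bound to a live session, so the query cannot
    # be built ahead of time or cached independently of one.
    return session.query(states.c.state_id, states.c.entity_id).filter(
        states.c.entity_id.isnot(None)
    )


def states_query() -> Select:
    # After: a plain Select built with no session, safe to embed inside
    # lambda_stmt and executed later via session.execute(...).
    return select(states.c.state_id, states.c.entity_id).where(
        states.c.entity_id.isnot(None)
    )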

tests/components/logbook/test_init.py

@@ -1390,6 +1390,59 @@ async def test_logbook_entity_matches_only(hass, hass_client, recorder_mock):
     assert json_dict[1]["context_user_id"] == "9400facee45711eaa9308bfd3d19e474"
 
 
+async def test_logbook_entity_matches_only_multiple_calls(
+    hass, hass_client, recorder_mock
+):
+    """Test the logbook view with a single entity and entity_matches_only called multiple times."""
+    await async_setup_component(hass, "logbook", {})
+    await async_setup_component(hass, "automation", {})
+
+    await async_recorder_block_till_done(hass)
+    await hass.async_block_till_done()
+    await hass.async_start()
+    await hass.async_block_till_done()
+
+    for automation_id in range(5):
+        hass.bus.async_fire(
+            EVENT_AUTOMATION_TRIGGERED,
+            {
+                ATTR_NAME: f"Mock automation {automation_id}",
+                ATTR_ENTITY_ID: f"automation.mock_{automation_id}_automation",
+            },
+        )
+
+    await async_wait_recording_done(hass)
+    client = await hass_client()
+
+    # Today time 00:00:00
+    start = dt_util.utcnow().date()
+    start_date = datetime(start.year, start.month, start.day)
+    end_time = start + timedelta(hours=24)
+
+    for automation_id in range(5):
+        # Test today entries with filter by end_time
+        response = await client.get(
+            f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_{automation_id}_automation&entity_matches_only"
+        )
+        assert response.status == HTTPStatus.OK
+        json_dict = await response.json()
+
+        assert len(json_dict) == 1
+        assert (
+            json_dict[0]["entity_id"] == f"automation.mock_{automation_id}_automation"
+        )
+
+    response = await client.get(
+        f"/api/logbook/{start_date.isoformat()}?end_time={end_time}&entity=automation.mock_0_automation,automation.mock_1_automation,automation.mock_2_automation&entity_matches_only"
+    )
+    assert response.status == HTTPStatus.OK
+    json_dict = await response.json()
+
+    assert len(json_dict) == 3
+    assert json_dict[0]["entity_id"] == "automation.mock_0_automation"
+    assert json_dict[1]["entity_id"] == "automation.mock_1_automation"
+    assert json_dict[2]["entity_id"] == "automation.mock_2_automation"
+
+
 async def test_custom_log_entry_discoverable_via_entity_matches_only(
     hass, hass_client, recorder_mock
 ):
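
The new test above makes repeated calls precisely because of the statement caching this commit introduces: after the first request, later requests reuse the cached lambda statement, so per-call values must flow in as tracked parameters rather than being frozen into the SQL. A self-contained illustration (assumed toy schema, not from the commit) of one cached statement correctly serving different bound values:

# Assumed toy schema; one cached lambda statement serving five different
# bound values, which is what the repeated requests above verify at the
# HTTP layer for entity_matches_only.
from sqlalchemy import (
    Column,
    Integer,
    MetaData,
    String,
    Table,
    create_engine,
    lambda_stmt,
    select,
)

metadata = MetaData()
automations = Table(
    "automations",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("entity_id", String),
)

engine = create_engine("sqlite://")
metadata.create_all(engine)


def one_automation(entity_id: str):
    # entity_id is extracted as a bound parameter on every call, even though
    # the SQL itself is compiled only once for this lambda.
    return lambda_stmt(
        lambda: select(automations).where(automations.c.entity_id == entity_id)
    )


with engine.connect() as conn:
    conn.execute(
        automations.insert(),
        [{"entity_id": f"automation.mock_{i}_automation"} for i in range(5)],
    )
    for i in range(5):
        rows = conn.execute(one_automation(f"automation.mock_{i}_automation")).all()
        assert len(rows) == 1  # each call sees its own automation, not the first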