Ensure filters are generated inside the lambda locks (#90418)

This commit is contained in:
J. Nick Koston 2023-03-28 08:50:10 -10:00 committed by GitHub
parent 9ccd43e5f1
commit d21433b6af
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 54 additions and 34 deletions

View File

@ -34,16 +34,11 @@ def statement_for_request(
# limited by the context_id and the yaml configured filter
if not entity_ids and not device_ids:
context_id_bin = ulid_to_bytes_or_none(context_id)
states_entity_filter = (
filters.states_metadata_entity_filter() if filters else None
)
events_entity_filter = filters.events_entity_filter() if filters else None
return all_stmt(
start_day,
end_day,
event_types,
states_entity_filter,
events_entity_filter,
filters,
context_id_bin,
)

View File

@ -2,7 +2,6 @@
from __future__ import annotations
from sqlalchemy import lambda_stmt
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.lambdas import StatementLambdaElement
from sqlalchemy.sql.selectable import Select
@ -11,6 +10,7 @@ from homeassistant.components.recorder.db_schema import (
Events,
States,
)
from homeassistant.components.recorder.filters import Filters
from .common import apply_states_filters, select_events_without_states, select_states
@ -19,8 +19,7 @@ def all_stmt(
start_day: float,
end_day: float,
event_types: tuple[str, ...],
states_entity_filter: ColumnElement | None = None,
events_entity_filter: ColumnElement | None = None,
filters: Filters | None,
context_id_bin: bytes | None = None,
) -> StatementLambdaElement:
"""Generate a logbook query for all entities."""
@ -36,19 +35,17 @@ def all_stmt(
context_id_bin, # type:ignore[arg-type]
),
)
else:
if events_entity_filter is not None:
stmt += lambda s: s.where(events_entity_filter)
if states_entity_filter is not None:
stmt += lambda s: s.union_all(
elif filters and filters.has_config:
stmt = stmt.add_criteria(
lambda q: q.filter(filters.events_entity_filter()).union_all( # type: ignore[union-attr]
_states_query_for_all(start_day, end_day).where(
# https://github.com/python/mypy/issues/2608
states_entity_filter # type:ignore[arg-type]
filters.states_metadata_entity_filter() # type: ignore[union-attr]
)
)
else:
stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
),
track_on=[filters],
)
else:
stmt += lambda s: s.union_all(_states_query_for_all(start_day, end_day))
stmt += lambda s: s.order_by(Events.time_fired_ts)
return stmt

View File

@ -125,8 +125,8 @@ class Filters:
def _generate_filter_for_columns(
self, columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ColumnElement | None:
"""Generate a filter from pre-comuted sets and pattern lists.
) -> ColumnElement:
"""Generate a filter from pre-computed sets and pattern lists.
This must match exactly how homeassistant.helpers.entityfilter works.
"""
@ -146,7 +146,9 @@ class Filters:
# Case 1 - No filter
# - All entities included
if not have_include and not have_exclude:
return None
raise RuntimeError(
"No filter configuration provided, check has_config before calling this method."
)
# Case 2 - Only includes
# - Entity listed in entities include: include
@ -193,7 +195,7 @@ class Filters:
# - Otherwise: exclude
return i_entities
def states_entity_filter(self) -> ColumnElement | None:
def states_entity_filter(self) -> ColumnElement:
"""Generate the States.entity_id filter query.
This is no longer used except by the legacy queries.
@ -206,7 +208,7 @@ class Filters:
# The type annotation should be improved so the type ignore can be removed
return self._generate_filter_for_columns((States.entity_id,), _encoder) # type: ignore[arg-type]
def states_metadata_entity_filter(self) -> ColumnElement | None:
def states_metadata_entity_filter(self) -> ColumnElement:
"""Generate the StatesMeta.entity_id filter query."""
def _encoder(data: Any) -> Any:
@ -232,7 +234,7 @@ class Filters:
(OLD_ENTITY_ID_IN_EVENT == JSON_NULL) | OLD_ENTITY_ID_IN_EVENT.is_(None)
),
# Needs https://github.com/bdraco/home-assistant/commit/bba91945006a46f3a01870008eb048e4f9cbb1ef
self._generate_filter_for_columns( # type: ignore[union-attr]
self._generate_filter_for_columns(
(ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder # type: ignore[arg-type]
).self_group(),
)

View File

@ -306,9 +306,8 @@ def _significant_states_stmt(
else:
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.states_entity_filter()
stmt = stmt.add_criteria(
lambda q: q.filter(entity_filter), track_on=[filters]
lambda q: q.filter(filters.states_entity_filter()), track_on=[filters] # type: ignore[union-attr]
)
if schema_version >= 31:
@ -713,8 +712,9 @@ def _get_states_for_all_stmt(
)
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.states_entity_filter()
stmt = stmt.add_criteria(lambda q: q.filter(entity_filter), track_on=[filters])
stmt = stmt.add_criteria(
lambda q: q.filter(filters.states_entity_filter()), track_on=[filters] # type: ignore[union-attr]
)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)

View File

@ -192,9 +192,9 @@ def _significant_states_stmt(
else:
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.states_metadata_entity_filter()
stmt = stmt.add_criteria(
lambda q: q.filter(entity_filter), track_on=[filters]
lambda q: q.filter(filters.states_metadata_entity_filter()), # type: ignore[union-attr]
track_on=[filters],
)
join_states_meta = True
@ -567,8 +567,10 @@ def _get_states_for_all_stmt(
)
stmt += _ignore_domains_filter
if filters and filters.has_config:
entity_filter = filters.states_metadata_entity_filter()
stmt = stmt.add_criteria(lambda q: q.filter(entity_filter), track_on=[filters])
stmt = stmt.add_criteria(
lambda q: q.filter(filters.states_metadata_entity_filter()), # type: ignore[union-attr]
track_on=[filters],
)
if join_attributes:
stmt += lambda q: q.outerjoin(
StateAttributes, (States.attributes_id == StateAttributes.attributes_id)

View File

@ -1,6 +1,9 @@
"""The tests for recorder filters."""
import pytest
from homeassistant.components.recorder.filters import (
Filters,
extract_include_exclude_filter_conf,
merge_include_exclude_filters,
)
@ -132,3 +135,24 @@ def test_merge_include_exclude_filters() -> None:
CONF_ENTITY_GLOBS: {"climate.*", "not_climate.*"},
},
}
async def test_an_empty_filter_raises() -> None:
"""Test empty filter raises when not guarding with has_config."""
filters = Filters()
assert not filters.has_config
with pytest.raises(
RuntimeError,
match="No filter configuration provided, check has_config before calling this method",
):
filters.states_metadata_entity_filter()
with pytest.raises(
RuntimeError,
match="No filter configuration provided, check has_config before calling this method",
):
filters.states_entity_filter()
with pytest.raises(
RuntimeError,
match="No filter configuration provided, check has_config before calling this method",
):
filters.events_entity_filter()