From 1f753ecd88d4453edf02154c70fc408741fc993f Mon Sep 17 00:00:00 2001 From: "J. Nick Koston" Date: Sun, 15 May 2022 01:04:23 -0500 Subject: [PATCH] Relocate sqlalchemy filter builder to recorder/filters.py (#71883) --- homeassistant/components/history/__init__.py | 129 +------------------ homeassistant/components/logbook/__init__.py | 6 +- homeassistant/components/logbook/queries.py | 8 +- homeassistant/components/recorder/filters.py | 119 +++++++++++++++++ homeassistant/components/recorder/history.py | 15 ++- tests/components/history/conftest.py | 13 +- tests/components/history/test_init.py | 70 +++++----- 7 files changed, 177 insertions(+), 183 deletions(-) create mode 100644 homeassistant/components/recorder/filters.py diff --git a/homeassistant/components/history/__init__.py b/homeassistant/components/history/__init__.py index 2ebe6405a7a..27acff54f99 100644 --- a/homeassistant/components/history/__init__.py +++ b/homeassistant/components/history/__init__.py @@ -6,20 +6,17 @@ from datetime import datetime as dt, timedelta from http import HTTPStatus import logging import time -from typing import Any, Literal, cast +from typing import Literal, cast from aiohttp import web -from sqlalchemy import not_, or_ -from sqlalchemy.ext.baked import BakedQuery -from sqlalchemy.orm import Query import voluptuous as vol from homeassistant.components import frontend, websocket_api from homeassistant.components.http import HomeAssistantView -from homeassistant.components.recorder import ( - get_instance, - history, - models as history_models, +from homeassistant.components.recorder import get_instance, history +from homeassistant.components.recorder.filters import ( + Filters, + sqlalchemy_filter_from_include_exclude_conf, ) from homeassistant.components.recorder.statistics import ( list_statistic_ids, @@ -28,13 +25,9 @@ from homeassistant.components.recorder.statistics import ( from homeassistant.components.recorder.util import session_scope from homeassistant.components.websocket_api import messages from homeassistant.components.websocket_api.const import JSON_DUMP -from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE from homeassistant.core import HomeAssistant import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.entityfilter import ( - CONF_ENTITY_GLOBS, - INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA, -) +from homeassistant.helpers.entityfilter import INCLUDE_EXCLUDE_BASE_FILTER_SCHEMA from homeassistant.helpers.typing import ConfigType import homeassistant.util.dt as dt_util @@ -46,10 +39,6 @@ HISTORY_USE_INCLUDE_ORDER = "history_use_include_order" CONF_ORDER = "use_include_order" -GLOB_TO_SQL_CHARS = { - 42: "%", # * - 46: "_", # . -} CONFIG_SCHEMA = vol.Schema( { @@ -410,112 +399,6 @@ class HistoryPeriodView(HomeAssistantView): return self.json(sorted_result) -def sqlalchemy_filter_from_include_exclude_conf(conf: ConfigType) -> Filters | None: - """Build a sql filter from config.""" - filters = Filters() - if exclude := conf.get(CONF_EXCLUDE): - filters.excluded_entities = exclude.get(CONF_ENTITIES, []) - filters.excluded_domains = exclude.get(CONF_DOMAINS, []) - filters.excluded_entity_globs = exclude.get(CONF_ENTITY_GLOBS, []) - if include := conf.get(CONF_INCLUDE): - filters.included_entities = include.get(CONF_ENTITIES, []) - filters.included_domains = include.get(CONF_DOMAINS, []) - filters.included_entity_globs = include.get(CONF_ENTITY_GLOBS, []) - - return filters if filters.has_config else None - - -class Filters: - """Container for the configured include and exclude filters.""" - - def __init__(self) -> None: - """Initialise the include and exclude filters.""" - self.excluded_entities: list[str] = [] - self.excluded_domains: list[str] = [] - self.excluded_entity_globs: list[str] = [] - - self.included_entities: list[str] = [] - self.included_domains: list[str] = [] - self.included_entity_globs: list[str] = [] - - def apply(self, query: Query) -> Query: - """Apply the entity filter.""" - if not self.has_config: - return query - - return query.filter(self.entity_filter()) - - @property - def has_config(self) -> bool: - """Determine if there is any filter configuration.""" - return bool( - self.excluded_entities - or self.excluded_domains - or self.excluded_entity_globs - or self.included_entities - or self.included_domains - or self.included_entity_globs - ) - - def bake(self, baked_query: BakedQuery) -> None: - """Update a baked query. - - Works the same as apply on a baked_query. - """ - if not self.has_config: - return - - baked_query += lambda q: q.filter(self.entity_filter()) - - def entity_filter(self) -> Any: - """Generate the entity filter query.""" - includes = [] - if self.included_domains: - includes.append( - or_( - *[ - history_models.States.entity_id.like(f"{domain}.%") - for domain in self.included_domains - ] - ).self_group() - ) - if self.included_entities: - includes.append(history_models.States.entity_id.in_(self.included_entities)) - for glob in self.included_entity_globs: - includes.append(_glob_to_like(glob)) - - excludes = [] - if self.excluded_domains: - excludes.append( - or_( - *[ - history_models.States.entity_id.like(f"{domain}.%") - for domain in self.excluded_domains - ] - ).self_group() - ) - if self.excluded_entities: - excludes.append(history_models.States.entity_id.in_(self.excluded_entities)) - for glob in self.excluded_entity_globs: - excludes.append(_glob_to_like(glob)) - - if not includes and not excludes: - return None - - if includes and not excludes: - return or_(*includes) - - if not includes and excludes: - return not_(or_(*excludes)) - - return or_(*includes) & not_(or_(*excludes)) - - -def _glob_to_like(glob_str: str) -> Any: - """Translate glob to sql.""" - return history_models.States.entity_id.like(glob_str.translate(GLOB_TO_SQL_CHARS)) - - def _entities_may_have_state_changes_after( hass: HomeAssistant, entity_ids: Iterable, start_time: dt ) -> bool: diff --git a/homeassistant/components/logbook/__init__.py b/homeassistant/components/logbook/__init__.py index 809f13ca598..4bef1f1a23d 100644 --- a/homeassistant/components/logbook/__init__.py +++ b/homeassistant/components/logbook/__init__.py @@ -17,12 +17,12 @@ import voluptuous as vol from homeassistant.components import frontend, websocket_api from homeassistant.components.automation import EVENT_AUTOMATION_TRIGGERED -from homeassistant.components.history import ( +from homeassistant.components.http import HomeAssistantView +from homeassistant.components.recorder import get_instance +from homeassistant.components.recorder.filters import ( Filters, sqlalchemy_filter_from_include_exclude_conf, ) -from homeassistant.components.http import HomeAssistantView -from homeassistant.components.recorder import get_instance from homeassistant.components.recorder.models import ( process_datetime_to_timestamp, process_timestamp_to_utc_isoformat, diff --git a/homeassistant/components/logbook/queries.py b/homeassistant/components/logbook/queries.py index 29dac31a432..89c530aec43 100644 --- a/homeassistant/components/logbook/queries.py +++ b/homeassistant/components/logbook/queries.py @@ -3,17 +3,17 @@ from __future__ import annotations from collections.abc import Iterable from datetime import datetime as dt -from typing import Any import sqlalchemy from sqlalchemy import lambda_stmt, select, union_all from sqlalchemy.orm import Query, aliased +from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.expression import literal from sqlalchemy.sql.lambdas import StatementLambdaElement from sqlalchemy.sql.selectable import Select -from homeassistant.components.history import Filters from homeassistant.components.proximity import DOMAIN as PROXIMITY_DOMAIN +from homeassistant.components.recorder.filters import Filters from homeassistant.components.recorder.models import ( ENTITY_ID_LAST_UPDATED_INDEX, LAST_UPDATED_INDEX, @@ -236,7 +236,7 @@ def _all_stmt( start_day: dt, end_day: dt, event_types: tuple[str, ...], - entity_filter: Any | None = None, + entity_filter: ClauseList | None = None, context_id: str | None = None, ) -> StatementLambdaElement: """Generate a logbook query for all entities.""" @@ -410,7 +410,7 @@ def _continuous_domain_matcher() -> sqlalchemy.or_: ).self_group() -def _not_uom_attributes_matcher() -> Any: +def _not_uom_attributes_matcher() -> ClauseList: """Prefilter ATTR_UNIT_OF_MEASUREMENT as its much faster in sql.""" return ~StateAttributes.shared_attrs.like( UNIT_OF_MEASUREMENT_JSON_LIKE diff --git a/homeassistant/components/recorder/filters.py b/homeassistant/components/recorder/filters.py new file mode 100644 index 00000000000..bb19dfc6d62 --- /dev/null +++ b/homeassistant/components/recorder/filters.py @@ -0,0 +1,119 @@ +"""Provide pre-made queries on top of the recorder component.""" +from __future__ import annotations + +from sqlalchemy import not_, or_ +from sqlalchemy.ext.baked import BakedQuery +from sqlalchemy.sql.elements import ClauseList + +from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE +from homeassistant.helpers.entityfilter import CONF_ENTITY_GLOBS +from homeassistant.helpers.typing import ConfigType + +from .models import States + +DOMAIN = "history" +HISTORY_FILTERS = "history_filters" + +GLOB_TO_SQL_CHARS = { + 42: "%", # * + 46: "_", # . +} + + +def sqlalchemy_filter_from_include_exclude_conf(conf: ConfigType) -> Filters | None: + """Build a sql filter from config.""" + filters = Filters() + if exclude := conf.get(CONF_EXCLUDE): + filters.excluded_entities = exclude.get(CONF_ENTITIES, []) + filters.excluded_domains = exclude.get(CONF_DOMAINS, []) + filters.excluded_entity_globs = exclude.get(CONF_ENTITY_GLOBS, []) + if include := conf.get(CONF_INCLUDE): + filters.included_entities = include.get(CONF_ENTITIES, []) + filters.included_domains = include.get(CONF_DOMAINS, []) + filters.included_entity_globs = include.get(CONF_ENTITY_GLOBS, []) + + return filters if filters.has_config else None + + +class Filters: + """Container for the configured include and exclude filters.""" + + def __init__(self) -> None: + """Initialise the include and exclude filters.""" + self.excluded_entities: list[str] = [] + self.excluded_domains: list[str] = [] + self.excluded_entity_globs: list[str] = [] + + self.included_entities: list[str] = [] + self.included_domains: list[str] = [] + self.included_entity_globs: list[str] = [] + + @property + def has_config(self) -> bool: + """Determine if there is any filter configuration.""" + return bool( + self.excluded_entities + or self.excluded_domains + or self.excluded_entity_globs + or self.included_entities + or self.included_domains + or self.included_entity_globs + ) + + def bake(self, baked_query: BakedQuery) -> BakedQuery: + """Update a baked query. + + Works the same as apply on a baked_query. + """ + if not self.has_config: + return + + baked_query += lambda q: q.filter(self.entity_filter()) + + def entity_filter(self) -> ClauseList: + """Generate the entity filter query.""" + includes = [] + if self.included_domains: + includes.append( + or_( + *[ + States.entity_id.like(f"{domain}.%") + for domain in self.included_domains + ] + ).self_group() + ) + if self.included_entities: + includes.append(States.entity_id.in_(self.included_entities)) + for glob in self.included_entity_globs: + includes.append(_glob_to_like(glob)) + + excludes = [] + if self.excluded_domains: + excludes.append( + or_( + *[ + States.entity_id.like(f"{domain}.%") + for domain in self.excluded_domains + ] + ).self_group() + ) + if self.excluded_entities: + excludes.append(States.entity_id.in_(self.excluded_entities)) + for glob in self.excluded_entity_globs: + excludes.append(_glob_to_like(glob)) + + if not includes and not excludes: + return None + + if includes and not excludes: + return or_(*includes) + + if not includes and excludes: + return not_(or_(*excludes)) + + return or_(*includes) & not_(or_(*excludes)) + + +def _glob_to_like(glob_str: str) -> ClauseList: + """Translate glob to sql.""" + return States.entity_id.like(glob_str.translate(GLOB_TO_SQL_CHARS)) diff --git a/homeassistant/components/recorder/history.py b/homeassistant/components/recorder/history.py index 7def35ce3ac..316e5ab27c8 100644 --- a/homeassistant/components/recorder/history.py +++ b/homeassistant/components/recorder/history.py @@ -25,6 +25,7 @@ from homeassistant.components.websocket_api.const import ( from homeassistant.core import HomeAssistant, State, split_entity_id import homeassistant.util.dt as dt_util +from .filters import Filters from .models import ( LazyState, RecorderRuns, @@ -163,7 +164,7 @@ def get_significant_states( start_time: datetime, end_time: datetime | None = None, entity_ids: list[str] | None = None, - filters: Any | None = None, + filters: Filters | None = None, include_start_time_state: bool = True, significant_changes_only: bool = True, minimal_response: bool = False, @@ -205,7 +206,7 @@ def _query_significant_states_with_session( start_time: datetime, end_time: datetime | None = None, entity_ids: list[str] | None = None, - filters: Any = None, + filters: Filters | None = None, significant_changes_only: bool = True, no_attributes: bool = False, ) -> list[Row]: @@ -281,7 +282,7 @@ def get_significant_states_with_session( start_time: datetime, end_time: datetime | None = None, entity_ids: list[str] | None = None, - filters: Any = None, + filters: Filters | None = None, include_start_time_state: bool = True, significant_changes_only: bool = True, minimal_response: bool = False, @@ -330,7 +331,7 @@ def get_full_significant_states_with_session( start_time: datetime, end_time: datetime | None = None, entity_ids: list[str] | None = None, - filters: Any = None, + filters: Filters | None = None, include_start_time_state: bool = True, significant_changes_only: bool = True, no_attributes: bool = False, @@ -549,7 +550,7 @@ def _most_recent_state_ids_subquery(query: Query) -> Query: def _get_states_baked_query_for_all( hass: HomeAssistant, - filters: Any | None = None, + filters: Filters | None = None, no_attributes: bool = False, ) -> BakedQuery: """Baked query to get states for all entities.""" @@ -573,7 +574,7 @@ def _get_rows_with_session( utc_point_in_time: datetime, entity_ids: list[str] | None = None, run: RecorderRuns | None = None, - filters: Any | None = None, + filters: Filters | None = None, no_attributes: bool = False, ) -> list[Row]: """Return the states at a specific point in time.""" @@ -640,7 +641,7 @@ def _sorted_states_to_dict( states: Iterable[Row], start_time: datetime, entity_ids: list[str] | None, - filters: Any = None, + filters: Filters | None = None, include_start_time_state: bool = True, minimal_response: bool = False, no_attributes: bool = False, diff --git a/tests/components/history/conftest.py b/tests/components/history/conftest.py index 5e81b444393..a2916153acc 100644 --- a/tests/components/history/conftest.py +++ b/tests/components/history/conftest.py @@ -2,6 +2,7 @@ import pytest from homeassistant.components import history +from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE from homeassistant.setup import setup_component @@ -13,13 +14,13 @@ def hass_history(hass_recorder): config = history.CONFIG_SCHEMA( { history.DOMAIN: { - history.CONF_INCLUDE: { - history.CONF_DOMAINS: ["media_player"], - history.CONF_ENTITIES: ["thermostat.test"], + CONF_INCLUDE: { + CONF_DOMAINS: ["media_player"], + CONF_ENTITIES: ["thermostat.test"], }, - history.CONF_EXCLUDE: { - history.CONF_DOMAINS: ["thermostat"], - history.CONF_ENTITIES: ["media_player.test"], + CONF_EXCLUDE: { + CONF_DOMAINS: ["thermostat"], + CONF_ENTITIES: ["media_player.test"], }, } } diff --git a/tests/components/history/test_init.py b/tests/components/history/test_init.py index 23e0550d6aa..0425c9bc2e7 100644 --- a/tests/components/history/test_init.py +++ b/tests/components/history/test_init.py @@ -11,7 +11,7 @@ from pytest import approx from homeassistant.components import history from homeassistant.components.recorder.history import get_significant_states from homeassistant.components.recorder.models import process_timestamp -from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES +from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE import homeassistant.core as ha from homeassistant.helpers.json import JSONEncoder from homeassistant.setup import async_setup_component @@ -186,9 +186,7 @@ def test_get_significant_states_exclude_domain(hass_history): config = history.CONFIG_SCHEMA( { ha.DOMAIN: {}, - history.DOMAIN: { - history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["media_player"]} - }, + history.DOMAIN: {CONF_EXCLUDE: {CONF_DOMAINS: ["media_player"]}}, } ) check_significant_states(hass, zero, four, states, config) @@ -207,9 +205,7 @@ def test_get_significant_states_exclude_entity(hass_history): config = history.CONFIG_SCHEMA( { ha.DOMAIN: {}, - history.DOMAIN: { - history.CONF_EXCLUDE: {history.CONF_ENTITIES: ["media_player.test"]} - }, + history.DOMAIN: {CONF_EXCLUDE: {CONF_ENTITIES: ["media_player.test"]}}, } ) check_significant_states(hass, zero, four, states, config) @@ -230,9 +226,9 @@ def test_get_significant_states_exclude(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_EXCLUDE: { - history.CONF_DOMAINS: ["thermostat"], - history.CONF_ENTITIES: ["media_player.test"], + CONF_EXCLUDE: { + CONF_DOMAINS: ["thermostat"], + CONF_ENTITIES: ["media_player.test"], } }, } @@ -257,10 +253,8 @@ def test_get_significant_states_exclude_include_entity(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_INCLUDE: { - history.CONF_ENTITIES: ["media_player.test", "thermostat.test"] - }, - history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["thermostat"]}, + CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test", "thermostat.test"]}, + CONF_EXCLUDE: {CONF_DOMAINS: ["thermostat"]}, }, } ) @@ -282,9 +276,7 @@ def test_get_significant_states_include_domain(hass_history): config = history.CONFIG_SCHEMA( { ha.DOMAIN: {}, - history.DOMAIN: { - history.CONF_INCLUDE: {history.CONF_DOMAINS: ["thermostat", "script"]} - }, + history.DOMAIN: {CONF_INCLUDE: {CONF_DOMAINS: ["thermostat", "script"]}}, } ) check_significant_states(hass, zero, four, states, config) @@ -306,9 +298,7 @@ def test_get_significant_states_include_entity(hass_history): config = history.CONFIG_SCHEMA( { ha.DOMAIN: {}, - history.DOMAIN: { - history.CONF_INCLUDE: {history.CONF_ENTITIES: ["media_player.test"]} - }, + history.DOMAIN: {CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test"]}}, } ) check_significant_states(hass, zero, four, states, config) @@ -330,9 +320,9 @@ def test_get_significant_states_include(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_INCLUDE: { - history.CONF_DOMAINS: ["thermostat"], - history.CONF_ENTITIES: ["media_player.test"], + CONF_INCLUDE: { + CONF_DOMAINS: ["thermostat"], + CONF_ENTITIES: ["media_player.test"], } }, } @@ -359,8 +349,8 @@ def test_get_significant_states_include_exclude_domain(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_INCLUDE: {history.CONF_DOMAINS: ["media_player"]}, - history.CONF_EXCLUDE: {history.CONF_DOMAINS: ["media_player"]}, + CONF_INCLUDE: {CONF_DOMAINS: ["media_player"]}, + CONF_EXCLUDE: {CONF_DOMAINS: ["media_player"]}, }, } ) @@ -386,8 +376,8 @@ def test_get_significant_states_include_exclude_entity(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_INCLUDE: {history.CONF_ENTITIES: ["media_player.test"]}, - history.CONF_EXCLUDE: {history.CONF_ENTITIES: ["media_player.test"]}, + CONF_INCLUDE: {CONF_ENTITIES: ["media_player.test"]}, + CONF_EXCLUDE: {CONF_ENTITIES: ["media_player.test"]}, }, } ) @@ -410,13 +400,13 @@ def test_get_significant_states_include_exclude(hass_history): { ha.DOMAIN: {}, history.DOMAIN: { - history.CONF_INCLUDE: { - history.CONF_DOMAINS: ["media_player"], - history.CONF_ENTITIES: ["thermostat.test"], + CONF_INCLUDE: { + CONF_DOMAINS: ["media_player"], + CONF_ENTITIES: ["thermostat.test"], }, - history.CONF_EXCLUDE: { - history.CONF_DOMAINS: ["thermostat"], - history.CONF_ENTITIES: ["media_player.test"], + CONF_EXCLUDE: { + CONF_DOMAINS: ["thermostat"], + CONF_ENTITIES: ["media_player.test"], }, }, } @@ -503,14 +493,14 @@ def test_get_significant_states_only(hass_history): def check_significant_states(hass, zero, four, states, config): """Check if significant states are retrieved.""" filters = history.Filters() - exclude = config[history.DOMAIN].get(history.CONF_EXCLUDE) + exclude = config[history.DOMAIN].get(CONF_EXCLUDE) if exclude: - filters.excluded_entities = exclude.get(history.CONF_ENTITIES, []) - filters.excluded_domains = exclude.get(history.CONF_DOMAINS, []) - include = config[history.DOMAIN].get(history.CONF_INCLUDE) + filters.excluded_entities = exclude.get(CONF_ENTITIES, []) + filters.excluded_domains = exclude.get(CONF_DOMAINS, []) + include = config[history.DOMAIN].get(CONF_INCLUDE) if include: - filters.included_entities = include.get(history.CONF_ENTITIES, []) - filters.included_domains = include.get(history.CONF_DOMAINS, []) + filters.included_entities = include.get(CONF_ENTITIES, []) + filters.included_domains = include.get(CONF_DOMAINS, []) hist = get_significant_states(hass, zero, four, filters=filters) assert states == hist @@ -1496,7 +1486,7 @@ async def test_history_during_period_with_use_include_order( { history.DOMAIN: { history.CONF_ORDER: True, - history.CONF_INCLUDE: { + CONF_INCLUDE: { CONF_ENTITIES: sort_order, CONF_DOMAINS: ["sensor"], },