Add explict type casts for postgresql filters (#72615)

This commit is contained in:
J. Nick Koston 2022-05-27 08:11:33 -10:00 committed by Paulus Schoutsen
parent 38c085f869
commit 13f953f49d

View File

@ -5,7 +5,7 @@ from collections.abc import Callable, Iterable
import json import json
from typing import Any from typing import Any
from sqlalchemy import Column, not_, or_ from sqlalchemy import JSON, Column, Text, cast, not_, or_
from sqlalchemy.sql.elements import ClauseList from sqlalchemy.sql.elements import ClauseList
from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE from homeassistant.const import CONF_DOMAINS, CONF_ENTITIES, CONF_EXCLUDE, CONF_INCLUDE
@ -110,8 +110,7 @@ class Filters:
"""Generate the entity filter query.""" """Generate the entity filter query."""
_encoder = json.dumps _encoder = json.dumps
return or_( return or_(
(ENTITY_ID_IN_EVENT == _encoder(None)) (ENTITY_ID_IN_EVENT == JSON.NULL) & (OLD_ENTITY_ID_IN_EVENT == JSON.NULL),
& (OLD_ENTITY_ID_IN_EVENT == _encoder(None)),
self._generate_filter_for_columns( self._generate_filter_for_columns(
(ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder (ENTITY_ID_IN_EVENT, OLD_ENTITY_ID_IN_EVENT), _encoder
).self_group(), ).self_group(),
@ -123,7 +122,7 @@ def _globs_to_like(
) -> ClauseList: ) -> ClauseList:
"""Translate glob to sql.""" """Translate glob to sql."""
return or_( return or_(
column.like(encoder(glob_str.translate(GLOB_TO_SQL_CHARS))) cast(column, Text()).like(encoder(glob_str.translate(GLOB_TO_SQL_CHARS)))
for glob_str in glob_strs for glob_str in glob_strs
for column in columns for column in columns
) )
@ -133,7 +132,7 @@ def _entity_matcher(
entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any] entity_ids: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ClauseList:
return or_( return or_(
column.in_([encoder(entity_id) for entity_id in entity_ids]) cast(column, Text()).in_([encoder(entity_id) for entity_id in entity_ids])
for column in columns for column in columns
) )
@ -142,5 +141,7 @@ def _domain_matcher(
domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any] domains: Iterable[str], columns: Iterable[Column], encoder: Callable[[Any], Any]
) -> ClauseList: ) -> ClauseList:
return or_( return or_(
column.like(encoder(f"{domain}.%")) for domain in domains for column in columns cast(column, Text()).like(encoder(f"{domain}.%"))
for domain in domains
for column in columns
) )