Handle different entity_id formats (#49969)

This commit is contained in:
Paulus Schoutsen 2021-05-01 20:30:28 -07:00 committed by GitHub
parent 3546ff2da2
commit 1bd9826684
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 51 additions and 2 deletions

View File

@ -358,13 +358,27 @@ class Recorder(threading.Thread):
self._event_listener = None
@callback
def _async_event_filter(self, event):
def _async_event_filter(self, event) -> bool:
"""Filter events."""
if event.event_type in self.exclude_t:
return False
entity_id = event.data.get(ATTR_ENTITY_ID)
return bool(entity_id is None or self.entity_filter(entity_id))
if entity_id is None:
return True
if isinstance(entity_id, str):
return self.entity_filter(entity_id)
if isinstance(entity_id, list):
for eid in entity_id:
if self.entity_filter(eid):
return True
return False
# Unknown what it is.
return True
def do_adhoc_purge(self, **kwargs):
"""Trigger an adhoc purge retaining keep_days worth of data."""

View File

@ -931,3 +931,38 @@ async def test_database_corruption_while_running(hass, tmpdir, caplog):
hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP)
await hass.async_block_till_done()
hass.stop()
def test_entity_id_filter(hass_recorder):
"""Test that entity ID filtering filters string and list."""
hass = hass_recorder(
{"include": {"domains": "hello"}, "exclude": {"domains": "hidden_domain"}}
)
for idx, data in enumerate(
(
{},
{"entity_id": "hello.world"},
{"entity_id": ["hello.world"]},
{"entity_id": ["hello.world", "hidden_domain.person"]},
{"entity_id": {"unexpected": "data"}},
)
):
hass.bus.fire("hello", data)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
db_events = list(session.query(Events).filter_by(event_type="hello"))
assert len(db_events) == idx + 1, data
for data in (
{"entity_id": "hidden_domain.person"},
{"entity_id": ["hidden_domain.person"]},
):
hass.bus.fire("hello", data)
wait_recording_done(hass)
with session_scope(hass=hass) as session:
db_events = list(session.query(Events).filter_by(event_type="hello"))
# Keep referring idx + 1, as no new events are being added
assert len(db_events) == idx + 1, data