Make recorder platform attribute exclude integration aware (#88357)

This commit is contained in:
J. Nick Koston 2023-02-18 03:08:59 -06:00 committed by GitHub
parent 97d9951d8a
commit 289bab6f87
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 65 additions and 7 deletions

View File

@ -30,6 +30,7 @@ from homeassistant.const import (
MATCH_ALL,
)
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
from homeassistant.helpers import entity_registry
from homeassistant.helpers.event import (
async_track_time_change,
async_track_time_interval,
@ -184,6 +185,7 @@ class Recorder(threading.Thread):
self._queue_watch = threading.Event()
self.engine: Engine | None = None
self.run_history = RunHistory()
self._entity_registry = entity_registry.async_get(hass)
# The entity_filter is exposed on the recorder instance so that
# it can be used to see if an entity is being recorded and is called
@ -875,7 +877,10 @@ class Recorder(threading.Thread):
try:
dbstate = States.from_event(event)
shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event(
event, self._exclude_attributes_by_domain, self.dialect_name
event,
self._entity_registry,
self._exclude_attributes_by_domain,
self.dialect_name,
)
except JSON_ENCODE_EXCEPTIONS as ex:
_LOGGER.warning(

View File

@ -41,6 +41,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null
import homeassistant.util.dt as dt_util
from homeassistant.util.json import (
@ -459,6 +460,7 @@ class StateAttributes(Base):
@staticmethod
def shared_attrs_bytes_from_event(
event: Event,
entity_registry: er.EntityRegistry,
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes:
@ -468,9 +470,13 @@ class StateAttributes(Base):
if state is None:
return b"{}"
domain = split_entity_id(state.entity_id)[0]
exclude_attrs = (
exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS
)
exclude_attrs = set(ALL_DOMAIN_EXCLUDE_ATTRS)
if base_platform_attrs := exclude_attrs_by_domain.get(domain):
exclude_attrs |= base_platform_attrs
if (reg_ent := entity_registry.async_get(state.entity_id)) and (
integration_attrs := exclude_attrs_by_domain.get(reg_ent.platform)
):
exclude_attrs |= integration_attrs
encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes
bytes_result = encoder(
{k: v for k, v in state.attributes.items() if k not in exclude_attrs}

View File

@ -46,6 +46,7 @@ from homeassistant.const import (
MAX_LENGTH_STATE_STATE,
)
from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.json import JSON_DUMP, json_bytes
import homeassistant.util.dt as dt_util
from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads
@ -436,6 +437,7 @@ class StateAttributes(Base): # type: ignore[misc,valid-type]
@staticmethod
def shared_attrs_bytes_from_event(
event: Event,
entity_registry: er.EntityRegistry,
exclude_attrs_by_domain: dict[str, set[str]],
dialect: SupportedDialect | None,
) -> bytes:

View File

@ -53,6 +53,7 @@ from homeassistant.components.recorder.services import (
)
from homeassistant.components.recorder.util import session_scope
from homeassistant.const import (
EVENT_COMPONENT_LOADED,
EVENT_HOMEASSISTANT_FINAL_WRITE,
EVENT_HOMEASSISTANT_STARTED,
EVENT_HOMEASSISTANT_STOP,
@ -61,7 +62,7 @@ from homeassistant.const import (
STATE_UNLOCKED,
)
from homeassistant.core import CoreState, Event, HomeAssistant, callback
from homeassistant.helpers import recorder as recorder_helper
from homeassistant.helpers import entity_registry as er, recorder as recorder_helper
from homeassistant.setup import async_setup_component, setup_component
from homeassistant.util import dt as dt_util
@ -77,6 +78,7 @@ from tests.common import (
async_fire_time_changed,
fire_time_changed,
get_test_home_assistant,
mock_platform,
)
from tests.typing import RecorderInstanceGenerator
@ -2027,3 +2029,44 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None:
},
)
assert connect_params[0]["charset"] == "utf8mb4"
async def test_excluding_attributes_by_integration(
recorder_mock: Recorder, hass: HomeAssistant, entity_registry: er.EntityRegistry
) -> None:
"""Test that an integration's recorder platform can exclude attributes."""
state = "restoring_from_db"
attributes = {"test_attr": 5, "excluded": 10}
entry = entity_registry.async_get_or_create(
"test",
"fake_integration",
"recorder",
)
entity_id = entry.entity_id
mock_platform(
hass,
"fake_integration.recorder",
Mock(exclude_attributes=lambda hass: {"excluded"}),
)
hass.config.components.add("fake_integration")
hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "fake_integration"})
await hass.async_block_till_done()
hass.states.async_set(entity_id, state, attributes)
await async_wait_recording_done(hass)
with session_scope(hass=hass) as session:
db_states = []
for db_state, db_state_attributes in session.query(
States, StateAttributes
).outerjoin(
StateAttributes, States.attributes_id == StateAttributes.attributes_id
):
db_states.append(db_state)
state = db_state.to_native()
state.attributes = db_state_attributes.to_native()
assert len(db_states) == 1
assert db_states[0].event_id is None
expected = _state_with_context(hass, entity_id)
expected.attributes = {"test_attr": 5}
assert state.as_dict() == expected.as_dict()

View File

@ -26,6 +26,7 @@ from homeassistant.const import EVENT_STATE_CHANGED
import homeassistant.core as ha
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import InvalidEntityFormatError
from homeassistant.helpers import entity_registry as er
from homeassistant.util import dt, dt as dt_util
@ -49,7 +50,7 @@ def test_from_event_to_db_state() -> None:
assert state.as_dict() == States.from_event(event).to_native().as_dict()
def test_from_event_to_db_state_attributes() -> None:
def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -> None:
"""Test converting event to db state attributes."""
attrs = {"this_attr": True}
state = ha.State("sensor.temperature", "18", attrs)
@ -60,8 +61,9 @@ def test_from_event_to_db_state_attributes() -> None:
)
db_attrs = StateAttributes()
dialect = SupportedDialect.MYSQL
db_attrs.shared_attrs = StateAttributes.shared_attrs_bytes_from_event(
event, {}, dialect
event, entity_registry, {}, dialect
)
assert db_attrs.to_native() == attrs