diff --git a/homeassistant/components/recorder/core.py b/homeassistant/components/recorder/core.py index 2604497eff4..a6cd56af733 100644 --- a/homeassistant/components/recorder/core.py +++ b/homeassistant/components/recorder/core.py @@ -30,6 +30,7 @@ from homeassistant.const import ( MATCH_ALL, ) from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback +from homeassistant.helpers import entity_registry from homeassistant.helpers.event import ( async_track_time_change, async_track_time_interval, @@ -184,6 +185,7 @@ class Recorder(threading.Thread): self._queue_watch = threading.Event() self.engine: Engine | None = None self.run_history = RunHistory() + self._entity_registry = entity_registry.async_get(hass) # The entity_filter is exposed on the recorder instance so that # it can be used to see if an entity is being recorded and is called @@ -875,7 +877,10 @@ class Recorder(threading.Thread): try: dbstate = States.from_event(event) shared_attrs_bytes = StateAttributes.shared_attrs_bytes_from_event( - event, self._exclude_attributes_by_domain, self.dialect_name + event, + self._entity_registry, + self._exclude_attributes_by_domain, + self.dialect_name, ) except JSON_ENCODE_EXCEPTIONS as ex: _LOGGER.warning( diff --git a/homeassistant/components/recorder/db_schema.py b/homeassistant/components/recorder/db_schema.py index c97f99b9e8c..88a8478047f 100644 --- a/homeassistant/components/recorder/db_schema.py +++ b/homeassistant/components/recorder/db_schema.py @@ -41,6 +41,7 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.json import JSON_DUMP, json_bytes, json_bytes_strip_null import homeassistant.util.dt as dt_util from homeassistant.util.json import ( @@ -459,6 +460,7 @@ class StateAttributes(Base): @staticmethod def shared_attrs_bytes_from_event( event: Event, + entity_registry: er.EntityRegistry, exclude_attrs_by_domain: dict[str, set[str]], dialect: SupportedDialect | None, ) -> bytes: @@ -468,9 +470,13 @@ class StateAttributes(Base): if state is None: return b"{}" domain = split_entity_id(state.entity_id)[0] - exclude_attrs = ( - exclude_attrs_by_domain.get(domain, set()) | ALL_DOMAIN_EXCLUDE_ATTRS - ) + exclude_attrs = set(ALL_DOMAIN_EXCLUDE_ATTRS) + if base_platform_attrs := exclude_attrs_by_domain.get(domain): + exclude_attrs |= base_platform_attrs + if (reg_ent := entity_registry.async_get(state.entity_id)) and ( + integration_attrs := exclude_attrs_by_domain.get(reg_ent.platform) + ): + exclude_attrs |= integration_attrs encoder = json_bytes_strip_null if dialect == PSQL_DIALECT else json_bytes bytes_result = encoder( {k: v for k, v in state.attributes.items() if k not in exclude_attrs} diff --git a/tests/components/recorder/db_schema_30.py b/tests/components/recorder/db_schema_30.py index e5408662e5f..91f7593969a 100644 --- a/tests/components/recorder/db_schema_30.py +++ b/tests/components/recorder/db_schema_30.py @@ -46,6 +46,7 @@ from homeassistant.const import ( MAX_LENGTH_STATE_STATE, ) from homeassistant.core import Context, Event, EventOrigin, State, split_entity_id +from homeassistant.helpers import entity_registry as er from homeassistant.helpers.json import JSON_DUMP, json_bytes import homeassistant.util.dt as dt_util from homeassistant.util.json import JSON_DECODE_EXCEPTIONS, json_loads @@ -436,6 +437,7 @@ class StateAttributes(Base): # type: ignore[misc,valid-type] @staticmethod def shared_attrs_bytes_from_event( event: Event, + entity_registry: er.EntityRegistry, exclude_attrs_by_domain: dict[str, set[str]], dialect: SupportedDialect | None, ) -> bytes: diff --git a/tests/components/recorder/test_init.py b/tests/components/recorder/test_init.py index 0da1799b8bd..7d188f982f1 100644 --- a/tests/components/recorder/test_init.py +++ b/tests/components/recorder/test_init.py @@ -53,6 +53,7 @@ from homeassistant.components.recorder.services import ( ) from homeassistant.components.recorder.util import session_scope from homeassistant.const import ( + EVENT_COMPONENT_LOADED, EVENT_HOMEASSISTANT_FINAL_WRITE, EVENT_HOMEASSISTANT_STARTED, EVENT_HOMEASSISTANT_STOP, @@ -61,7 +62,7 @@ from homeassistant.const import ( STATE_UNLOCKED, ) from homeassistant.core import CoreState, Event, HomeAssistant, callback -from homeassistant.helpers import recorder as recorder_helper +from homeassistant.helpers import entity_registry as er, recorder as recorder_helper from homeassistant.setup import async_setup_component, setup_component from homeassistant.util import dt as dt_util @@ -77,6 +78,7 @@ from tests.common import ( async_fire_time_changed, fire_time_changed, get_test_home_assistant, + mock_platform, ) from tests.typing import RecorderInstanceGenerator @@ -2027,3 +2029,44 @@ async def test_connect_args_priority(hass: HomeAssistant, config_url) -> None: }, ) assert connect_params[0]["charset"] == "utf8mb4" + + +async def test_excluding_attributes_by_integration( + recorder_mock: Recorder, hass: HomeAssistant, entity_registry: er.EntityRegistry +) -> None: + """Test that an integration's recorder platform can exclude attributes.""" + state = "restoring_from_db" + attributes = {"test_attr": 5, "excluded": 10} + entry = entity_registry.async_get_or_create( + "test", + "fake_integration", + "recorder", + ) + entity_id = entry.entity_id + mock_platform( + hass, + "fake_integration.recorder", + Mock(exclude_attributes=lambda hass: {"excluded"}), + ) + hass.config.components.add("fake_integration") + hass.bus.async_fire(EVENT_COMPONENT_LOADED, {"component": "fake_integration"}) + await hass.async_block_till_done() + hass.states.async_set(entity_id, state, attributes) + await async_wait_recording_done(hass) + + with session_scope(hass=hass) as session: + db_states = [] + for db_state, db_state_attributes in session.query( + States, StateAttributes + ).outerjoin( + StateAttributes, States.attributes_id == StateAttributes.attributes_id + ): + db_states.append(db_state) + state = db_state.to_native() + state.attributes = db_state_attributes.to_native() + assert len(db_states) == 1 + assert db_states[0].event_id is None + + expected = _state_with_context(hass, entity_id) + expected.attributes = {"test_attr": 5} + assert state.as_dict() == expected.as_dict() diff --git a/tests/components/recorder/test_models.py b/tests/components/recorder/test_models.py index 5934e37c583..8089ea1ed7c 100644 --- a/tests/components/recorder/test_models.py +++ b/tests/components/recorder/test_models.py @@ -26,6 +26,7 @@ from homeassistant.const import EVENT_STATE_CHANGED import homeassistant.core as ha from homeassistant.core import HomeAssistant from homeassistant.exceptions import InvalidEntityFormatError +from homeassistant.helpers import entity_registry as er from homeassistant.util import dt, dt as dt_util @@ -49,7 +50,7 @@ def test_from_event_to_db_state() -> None: assert state.as_dict() == States.from_event(event).to_native().as_dict() -def test_from_event_to_db_state_attributes() -> None: +def test_from_event_to_db_state_attributes(entity_registry: er.EntityRegistry) -> None: """Test converting event to db state attributes.""" attrs = {"this_attr": True} state = ha.State("sensor.temperature", "18", attrs) @@ -60,8 +61,9 @@ def test_from_event_to_db_state_attributes() -> None: ) db_attrs = StateAttributes() dialect = SupportedDialect.MYSQL + db_attrs.shared_attrs = StateAttributes.shared_attrs_bytes_from_event( - event, {}, dialect + event, entity_registry, {}, dialect ) assert db_attrs.to_native() == attrs