From 922d4c42a3e05c611e7d1cdd04278746138bdc00 Mon Sep 17 00:00:00 2001 From: Diogo Gomes Date: Tue, 28 Sep 2021 08:30:21 +0100 Subject: [PATCH] Inherit Filter sensor state_class from source sensor (#56407) --- homeassistant/components/filter/sensor.py | 5 +++++ tests/components/filter/test_sensor.py | 22 ++++++++++++++++++---- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/homeassistant/components/filter/sensor.py b/homeassistant/components/filter/sensor.py index f9705887549..665ef6b6ecd 100644 --- a/homeassistant/components/filter/sensor.py +++ b/homeassistant/components/filter/sensor.py @@ -15,6 +15,7 @@ from homeassistant.components.binary_sensor import DOMAIN as BINARY_SENSOR_DOMAI from homeassistant.components.input_number import DOMAIN as INPUT_NUMBER_DOMAIN from homeassistant.components.recorder import history from homeassistant.components.sensor import ( + ATTR_STATE_CLASS, DEVICE_CLASSES as SENSOR_DEVICE_CLASSES, DOMAIN as SENSOR_DOMAIN, PLATFORM_SCHEMA, @@ -191,6 +192,7 @@ class SensorFilter(SensorEntity): self._filters = filters self._icon = None self._device_class = None + self._attr_state_class = None @callback def _update_filter_sensor_state_event(self, event): @@ -248,6 +250,9 @@ class SensorFilter(SensorEntity): ): self._device_class = new_state.attributes.get(ATTR_DEVICE_CLASS) + if self._attr_state_class is None: + self._attr_state_class = new_state.attributes.get(ATTR_STATE_CLASS) + if self._unit_of_measurement is None: self._unit_of_measurement = new_state.attributes.get( ATTR_UNIT_OF_MEASUREMENT diff --git a/tests/components/filter/test_sensor.py b/tests/components/filter/test_sensor.py index 60fae0fc5be..89e8758c661 100644 --- a/tests/components/filter/test_sensor.py +++ b/tests/components/filter/test_sensor.py @@ -15,8 +15,17 @@ from homeassistant.components.filter.sensor import ( TimeSMAFilter, TimeThrottleFilter, ) -from homeassistant.components.sensor import DEVICE_CLASS_TEMPERATURE -from homeassistant.const import SERVICE_RELOAD, STATE_UNAVAILABLE, STATE_UNKNOWN +from homeassistant.components.sensor import ( + ATTR_STATE_CLASS, + DEVICE_CLASS_TEMPERATURE, + STATE_CLASS_TOTAL_INCREASING, +) +from homeassistant.const import ( + ATTR_DEVICE_CLASS, + SERVICE_RELOAD, + STATE_UNAVAILABLE, + STATE_UNKNOWN, +) import homeassistant.core as ha from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util @@ -264,12 +273,17 @@ async def test_setup(hass): hass.states.async_set( "sensor.test_monitored", 1, - {"icon": "mdi:test", "device_class": DEVICE_CLASS_TEMPERATURE}, + { + "icon": "mdi:test", + ATTR_DEVICE_CLASS: DEVICE_CLASS_TEMPERATURE, + ATTR_STATE_CLASS: STATE_CLASS_TOTAL_INCREASING, + }, ) await hass.async_block_till_done() state = hass.states.get("sensor.test") assert state.attributes["icon"] == "mdi:test" - assert state.attributes["device_class"] == DEVICE_CLASS_TEMPERATURE + assert state.attributes[ATTR_DEVICE_CLASS] == DEVICE_CLASS_TEMPERATURE + assert state.attributes[ATTR_STATE_CLASS] == STATE_CLASS_TOTAL_INCREASING assert state.state == "1.0"