diff --git a/homeassistant/components/derivative/sensor.py b/homeassistant/components/derivative/sensor.py index e1cc278137c..793e8edc769 100644 --- a/homeassistant/components/derivative/sensor.py +++ b/homeassistant/components/derivative/sensor.py @@ -19,7 +19,12 @@ from homeassistant.const import ( UnitOfTime, ) from homeassistant.core import Event, HomeAssistant, State, callback -from homeassistant.helpers import config_validation as cv, entity_registry as er +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) +from homeassistant.helpers.entity import DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.event import async_track_state_change_event from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType @@ -86,6 +91,27 @@ async def async_setup_entry( registry, config_entry.options[CONF_SOURCE] ) + source_entity = registry.async_get(source_entity_id) + dev_reg = dr.async_get(hass) + # Resolve source entity device + if ( + (source_entity is not None) + and (source_entity.device_id is not None) + and ( + ( + device := dev_reg.async_get( + device_id=source_entity.device_id, + ) + ) + is not None + ) + ): + device_info = DeviceInfo( + identifiers=device.identifiers, + ) + else: + device_info = None + unit_prefix = config_entry.options[CONF_UNIT_PREFIX] if unit_prefix == "none": unit_prefix = None @@ -99,6 +125,7 @@ async def async_setup_entry( unit_of_measurement=None, unit_prefix=unit_prefix, unit_time=config_entry.options[CONF_UNIT_TIME], + device_info=device_info, ) async_add_entities([derivative_sensor]) @@ -142,9 +169,11 @@ class DerivativeSensor(RestoreSensor, SensorEntity): unit_prefix: str | None, unit_time: UnitOfTime, unique_id: str | None, + device_info: DeviceInfo | None = None, ) -> None: """Initialize the derivative sensor.""" self._attr_unique_id = unique_id + self._attr_device_info = device_info self._sensor_source_id = source_entity self._round_digits = round_digits self._state: float | int | Decimal = 0 diff --git a/tests/components/derivative/test_sensor.py b/tests/components/derivative/test_sensor.py index 8260e5a0ada..513e9597572 100644 --- a/tests/components/derivative/test_sensor.py +++ b/tests/components/derivative/test_sensor.py @@ -5,11 +5,15 @@ import random from freezegun import freeze_time +from homeassistant.components.derivative.const import DOMAIN from homeassistant.const import UnitOfPower, UnitOfTime from homeassistant.core import HomeAssistant +from homeassistant.helpers import device_registry as dr, entity_registry as er from homeassistant.setup import async_setup_component import homeassistant.util.dt as dt_util +from tests.common import MockConfigEntry + async def test_state(hass: HomeAssistant) -> None: """Test derivative sensor state.""" @@ -342,3 +346,47 @@ async def test_suffix(hass: HomeAssistant) -> None: # Testing a network speed sensor at 1000 bytes/s over 10s = 10kbytes/s2 assert round(float(state.state), config["sensor"]["round"]) == 0.0 + + +async def test_device_id(hass: HomeAssistant) -> None: + """Test for source entity device for Derivative.""" + device_registry = dr.async_get(hass) + entity_registry = er.async_get(hass) + + source_config_entry = MockConfigEntry() + source_device_entry = device_registry.async_get_or_create( + config_entry_id=source_config_entry.entry_id, + identifiers={("sensor", "identifier_test")}, + ) + source_entity = entity_registry.async_get_or_create( + "sensor", + "test", + "source", + config_entry=source_config_entry, + device_id=source_device_entry.id, + ) + await hass.async_block_till_done() + assert entity_registry.async_get("sensor.test_source") is not None + + derivative_config_entry = MockConfigEntry( + data={}, + domain=DOMAIN, + options={ + "name": "Derivative", + "round": 1.0, + "source": "sensor.test_source", + "time_window": {"seconds": 0.0}, + "unit_prefix": "k", + "unit_time": "min", + }, + title="Derivative", + ) + + derivative_config_entry.add_to_hass(hass) + + assert await hass.config_entries.async_setup(derivative_config_entry.entry_id) + await hass.async_block_till_done() + + derivative_entity = entity_registry.async_get("sensor.derivative") + assert derivative_entity is not None + assert derivative_entity.device_id == source_entity.device_id