From c814f4f307f33974b82939d4e1fb1a384f3ef0d9 Mon Sep 17 00:00:00 2001 From: Michael <35783820+mib1185@users.noreply.github.com> Date: Fri, 7 Feb 2025 20:11:04 +0100 Subject: [PATCH] Explicitly pass in the config entry in amberelectric coordinator init (#137700) * explicitly pass in the config_entry in amberelectric coordinator init * fix amberelectric tests --- .../components/amberelectric/__init__.py | 7 +--- .../components/amberelectric/binary_sensor.py | 3 +- .../components/amberelectric/coordinator.py | 12 +++++- .../components/amberelectric/sensor.py | 3 +- .../amberelectric/test_coordinator.py | 37 +++++++++++++++---- 5 files changed, 45 insertions(+), 17 deletions(-) diff --git a/homeassistant/components/amberelectric/__init__.py b/homeassistant/components/amberelectric/__init__.py index 29d8f166f2a..9eab6f42ad3 100644 --- a/homeassistant/components/amberelectric/__init__.py +++ b/homeassistant/components/amberelectric/__init__.py @@ -2,14 +2,11 @@ import amberelectric -from homeassistant.config_entries import ConfigEntry from homeassistant.const import CONF_API_TOKEN from homeassistant.core import HomeAssistant from .const import CONF_SITE_ID, PLATFORMS -from .coordinator import AmberUpdateCoordinator - -type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator] +from .coordinator import AmberConfigEntry, AmberUpdateCoordinator async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> bool: @@ -19,7 +16,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> boo api_instance = amberelectric.AmberApi(api_client) site_id = entry.data[CONF_SITE_ID] - coordinator = AmberUpdateCoordinator(hass, api_instance, site_id) + coordinator = AmberUpdateCoordinator(hass, entry, api_instance, site_id) await coordinator.async_config_entry_first_refresh() entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) diff --git a/homeassistant/components/amberelectric/binary_sensor.py b/homeassistant/components/amberelectric/binary_sensor.py index a9fa00d0129..66292ea5524 100644 --- a/homeassistant/components/amberelectric/binary_sensor.py +++ b/homeassistant/components/amberelectric/binary_sensor.py @@ -12,9 +12,8 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity -from . import AmberConfigEntry from .const import ATTRIBUTION -from .coordinator import AmberUpdateCoordinator +from .coordinator import AmberConfigEntry, AmberUpdateCoordinator PRICE_SPIKE_ICONS = { "none": "mdi:power-plug", diff --git a/homeassistant/components/amberelectric/coordinator.py b/homeassistant/components/amberelectric/coordinator.py index 57028e07d21..1edf64ba0d6 100644 --- a/homeassistant/components/amberelectric/coordinator.py +++ b/homeassistant/components/amberelectric/coordinator.py @@ -13,11 +13,14 @@ from amberelectric.models.forecast_interval import ForecastInterval from amberelectric.models.price_descriptor import PriceDescriptor from amberelectric.rest import ApiException +from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from .const import LOGGER +type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator] + def is_current(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool: """Return true if the supplied interval is a CurrentInterval.""" @@ -70,13 +73,20 @@ def normalize_descriptor(descriptor: PriceDescriptor | None) -> str | None: class AmberUpdateCoordinator(DataUpdateCoordinator): """AmberUpdateCoordinator - In charge of downloading the data for a site, which all the sensors read.""" + config_entry: AmberConfigEntry + def __init__( - self, hass: HomeAssistant, api: amberelectric.AmberApi, site_id: str + self, + hass: HomeAssistant, + config_entry: AmberConfigEntry, + api: amberelectric.AmberApi, + site_id: str, ) -> None: """Initialise the data service.""" super().__init__( hass, LOGGER, + config_entry=config_entry, name="amberelectric", update_interval=timedelta(minutes=1), ) diff --git a/homeassistant/components/amberelectric/sensor.py b/homeassistant/components/amberelectric/sensor.py index cdf40e5804d..49d6e5f4eac 100644 --- a/homeassistant/components/amberelectric/sensor.py +++ b/homeassistant/components/amberelectric/sensor.py @@ -22,9 +22,8 @@ from homeassistant.core import HomeAssistant from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.update_coordinator import CoordinatorEntity -from . import AmberConfigEntry from .const import ATTRIBUTION -from .coordinator import AmberUpdateCoordinator, normalize_descriptor +from .coordinator import AmberConfigEntry, AmberUpdateCoordinator, normalize_descriptor UNIT = f"{CURRENCY_DOLLAR}/{UnitOfEnergy.KILO_WATT_HOUR}" diff --git a/tests/components/amberelectric/test_coordinator.py b/tests/components/amberelectric/test_coordinator.py index 0a8f5b874fa..6faabc924b4 100644 --- a/tests/components/amberelectric/test_coordinator.py +++ b/tests/components/amberelectric/test_coordinator.py @@ -16,10 +16,12 @@ from amberelectric.models.spike_status import SpikeStatus from dateutil import parser import pytest +from homeassistant.components.amberelectric.const import CONF_SITE_ID, CONF_SITE_NAME from homeassistant.components.amberelectric.coordinator import ( AmberUpdateCoordinator, normalize_descriptor, ) +from homeassistant.const import CONF_API_TOKEN from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import UpdateFailed @@ -33,6 +35,17 @@ from .helpers import ( generate_current_interval, ) +from tests.common import MockConfigEntry + +MOCKED_ENTRY = MockConfigEntry( + domain="amberelectric", + data={ + CONF_SITE_NAME: "mock_title", + CONF_API_TOKEN: "psk_0000000000000000", + CONF_SITE_ID: GENERAL_ONLY_SITE_ID, + }, +) + @pytest.fixture(name="current_price_api") def mock_api_current_price() -> Generator: @@ -101,7 +114,9 @@ async def test_fetch_general_site(hass: HomeAssistant, current_price_api: Mock) """Test fetching a site with only a general channel.""" current_price_api.get_current_prices.return_value = GENERAL_CHANNEL - data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) + data_service = AmberUpdateCoordinator( + hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID + ) result = await data_service._async_update_data() current_price_api.get_current_prices.assert_called_with( @@ -130,7 +145,9 @@ async def test_fetch_no_general_site( """Test fetching a site with no general channel.""" current_price_api.get_current_prices.return_value = CONTROLLED_LOAD_CHANNEL - data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) + data_service = AmberUpdateCoordinator( + hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID + ) with pytest.raises(UpdateFailed): await data_service._async_update_data() @@ -143,7 +160,9 @@ async def test_fetch_api_error(hass: HomeAssistant, current_price_api: Mock) -> """Test that the old values are maintained if a second call fails.""" current_price_api.get_current_prices.return_value = GENERAL_CHANNEL - data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) + data_service = AmberUpdateCoordinator( + hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID + ) result = await data_service._async_update_data() current_price_api.get_current_prices.assert_called_with( @@ -193,7 +212,7 @@ async def test_fetch_general_and_controlled_load_site( GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL ) data_service = AmberUpdateCoordinator( - hass, current_price_api, GENERAL_AND_CONTROLLED_SITE_ID + hass, MOCKED_ENTRY, current_price_api, GENERAL_AND_CONTROLLED_SITE_ID ) result = await data_service._async_update_data() @@ -233,7 +252,7 @@ async def test_fetch_general_and_feed_in_site( GENERAL_CHANNEL + FEED_IN_CHANNEL ) data_service = AmberUpdateCoordinator( - hass, current_price_api, GENERAL_AND_FEED_IN_SITE_ID + hass, MOCKED_ENTRY, current_price_api, GENERAL_AND_FEED_IN_SITE_ID ) result = await data_service._async_update_data() @@ -273,7 +292,9 @@ async def test_fetch_potential_spike( ] general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL current_price_api.get_current_prices.return_value = general_channel - data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) + data_service = AmberUpdateCoordinator( + hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID + ) result = await data_service._async_update_data() assert result["grid"]["price_spike"] == "potential" @@ -288,6 +309,8 @@ async def test_fetch_spike(hass: HomeAssistant, current_price_api: Mock) -> None ] general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE current_price_api.get_current_prices.return_value = general_channel - data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) + data_service = AmberUpdateCoordinator( + hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID + ) result = await data_service._async_update_data() assert result["grid"]["price_spike"] == "spike"