Explicitly pass in the config entry in amberelectric coordinator init (#137700)

* explicitly pass in the config_entry in amberelectric coordinator init

* fix amberelectric tests
This commit is contained in:
Michael 2025-02-07 20:11:04 +01:00 committed by GitHub
parent 1ff9ec661c
commit c814f4f307
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 45 additions and 17 deletions

View File

@ -2,14 +2,11 @@
import amberelectric
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_TOKEN
from homeassistant.core import HomeAssistant
from .const import CONF_SITE_ID, PLATFORMS
from .coordinator import AmberUpdateCoordinator
type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator]
from .coordinator import AmberConfigEntry, AmberUpdateCoordinator
async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> bool:
@ -19,7 +16,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> boo
api_instance = amberelectric.AmberApi(api_client)
site_id = entry.data[CONF_SITE_ID]
coordinator = AmberUpdateCoordinator(hass, api_instance, site_id)
coordinator = AmberUpdateCoordinator(hass, entry, api_instance, site_id)
await coordinator.async_config_entry_first_refresh()
entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)

View File

@ -12,9 +12,8 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AmberConfigEntry
from .const import ATTRIBUTION
from .coordinator import AmberUpdateCoordinator
from .coordinator import AmberConfigEntry, AmberUpdateCoordinator
PRICE_SPIKE_ICONS = {
"none": "mdi:power-plug",

View File

@ -13,11 +13,14 @@ from amberelectric.models.forecast_interval import ForecastInterval
from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.rest import ApiException
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from .const import LOGGER
type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator]
def is_current(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool:
"""Return true if the supplied interval is a CurrentInterval."""
@ -70,13 +73,20 @@ def normalize_descriptor(descriptor: PriceDescriptor | None) -> str | None:
class AmberUpdateCoordinator(DataUpdateCoordinator):
"""AmberUpdateCoordinator - In charge of downloading the data for a site, which all the sensors read."""
config_entry: AmberConfigEntry
def __init__(
self, hass: HomeAssistant, api: amberelectric.AmberApi, site_id: str
self,
hass: HomeAssistant,
config_entry: AmberConfigEntry,
api: amberelectric.AmberApi,
site_id: str,
) -> None:
"""Initialise the data service."""
super().__init__(
hass,
LOGGER,
config_entry=config_entry,
name="amberelectric",
update_interval=timedelta(minutes=1),
)

View File

@ -22,9 +22,8 @@ from homeassistant.core import HomeAssistant
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import AmberConfigEntry
from .const import ATTRIBUTION
from .coordinator import AmberUpdateCoordinator, normalize_descriptor
from .coordinator import AmberConfigEntry, AmberUpdateCoordinator, normalize_descriptor
UNIT = f"{CURRENCY_DOLLAR}/{UnitOfEnergy.KILO_WATT_HOUR}"

View File

@ -16,10 +16,12 @@ from amberelectric.models.spike_status import SpikeStatus
from dateutil import parser
import pytest
from homeassistant.components.amberelectric.const import CONF_SITE_ID, CONF_SITE_NAME
from homeassistant.components.amberelectric.coordinator import (
AmberUpdateCoordinator,
normalize_descriptor,
)
from homeassistant.const import CONF_API_TOKEN
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import UpdateFailed
@ -33,6 +35,17 @@ from .helpers import (
generate_current_interval,
)
from tests.common import MockConfigEntry
MOCKED_ENTRY = MockConfigEntry(
domain="amberelectric",
data={
CONF_SITE_NAME: "mock_title",
CONF_API_TOKEN: "psk_0000000000000000",
CONF_SITE_ID: GENERAL_ONLY_SITE_ID,
},
)
@pytest.fixture(name="current_price_api")
def mock_api_current_price() -> Generator:
@ -101,7 +114,9 @@ async def test_fetch_general_site(hass: HomeAssistant, current_price_api: Mock)
"""Test fetching a site with only a general channel."""
current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
data_service = AmberUpdateCoordinator(
hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID
)
result = await data_service._async_update_data()
current_price_api.get_current_prices.assert_called_with(
@ -130,7 +145,9 @@ async def test_fetch_no_general_site(
"""Test fetching a site with no general channel."""
current_price_api.get_current_prices.return_value = CONTROLLED_LOAD_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
data_service = AmberUpdateCoordinator(
hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID
)
with pytest.raises(UpdateFailed):
await data_service._async_update_data()
@ -143,7 +160,9 @@ async def test_fetch_api_error(hass: HomeAssistant, current_price_api: Mock) ->
"""Test that the old values are maintained if a second call fails."""
current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
data_service = AmberUpdateCoordinator(
hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID
)
result = await data_service._async_update_data()
current_price_api.get_current_prices.assert_called_with(
@ -193,7 +212,7 @@ async def test_fetch_general_and_controlled_load_site(
GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL
)
data_service = AmberUpdateCoordinator(
hass, current_price_api, GENERAL_AND_CONTROLLED_SITE_ID
hass, MOCKED_ENTRY, current_price_api, GENERAL_AND_CONTROLLED_SITE_ID
)
result = await data_service._async_update_data()
@ -233,7 +252,7 @@ async def test_fetch_general_and_feed_in_site(
GENERAL_CHANNEL + FEED_IN_CHANNEL
)
data_service = AmberUpdateCoordinator(
hass, current_price_api, GENERAL_AND_FEED_IN_SITE_ID
hass, MOCKED_ENTRY, current_price_api, GENERAL_AND_FEED_IN_SITE_ID
)
result = await data_service._async_update_data()
@ -273,7 +292,9 @@ async def test_fetch_potential_spike(
]
general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL
current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
data_service = AmberUpdateCoordinator(
hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID
)
result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "potential"
@ -288,6 +309,8 @@ async def test_fetch_spike(hass: HomeAssistant, current_price_api: Mock) -> None
]
general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE
current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
data_service = AmberUpdateCoordinator(
hass, MOCKED_ENTRY, current_price_api, GENERAL_ONLY_SITE_ID
)
result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "spike"