Update Amberelectric to use amberelectric version 2.0.12 (#125701)

* Add price descriptor attribute to price sensors

* Adding price descriptor sensor

* Use correct number of sensors in spike sensor tests

* Add tests for normalize_descriptor

* Removing debug message

* Removing price_descriptor attribute from the current sensor

* Refactoring everything to use the new API

* Use SiteStatus object, fix some typnig issues

* fixing test

* Adding predicted price to attributes

* Fix advanced price in forecast

* Testing advanced forecasts

* WIP: Adding advanced forecast sensor. need to add attributes, and tests

* Add advanced price attributes

* Adding forecasts to the advanced price sensor

* Appending forecasts corectly

* Appending forecasts correctly. Again

* Removing sensor for the moment. Will do in another PR

* Fix failing test that had the wrong sign

* Adding test to improve coverage on config_flow test

* Bumping amberelectric dependency to version 2

* Remove advanced code from helpers

* Use f-strings

* Bumping to version 2.0.1

* Bumping amberelectric to version 2.0.2

* Bumping amberelectric to version 2.0.2

* Bumping verion amberelectric.py to 2.0.3. Using correct enums

* Bumping amberelectric.py version to 2.0.4

* Bump version to 2.0.5

* Fix formatting

* fixing mocks to include interval_length

* Bumping to 2.0.6

* Bumping to 2.0.7

* Bumping to 2.0.8

* Bumping to 2.0.9

* Bumping version 2.0.12
This commit is contained in:
Myles Eftos 2024-11-20 21:27:24 +11:00 committed by GitHub
parent 2cfacd8bc5
commit 621c66a214
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 352 additions and 254 deletions

View File

@ -1,7 +1,6 @@
"""Support for Amber Electric.""" """Support for Amber Electric."""
from amberelectric import Configuration import amberelectric
from amberelectric.api import amber_api
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_TOKEN from homeassistant.const import CONF_API_TOKEN
@ -15,8 +14,9 @@ type AmberConfigEntry = ConfigEntry[AmberUpdateCoordinator]
async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: AmberConfigEntry) -> bool:
"""Set up Amber Electric from a config entry.""" """Set up Amber Electric from a config entry."""
configuration = Configuration(access_token=entry.data[CONF_API_TOKEN]) configuration = amberelectric.Configuration(access_token=entry.data[CONF_API_TOKEN])
api_instance = amber_api.AmberApi.create(configuration) api_client = amberelectric.ApiClient(configuration)
api_instance = amberelectric.AmberApi(api_client)
site_id = entry.data[CONF_SITE_ID] site_id = entry.data[CONF_SITE_ID]
coordinator = AmberUpdateCoordinator(hass, api_instance, site_id) coordinator = AmberUpdateCoordinator(hass, api_instance, site_id)

View File

@ -3,8 +3,8 @@
from __future__ import annotations from __future__ import annotations
import amberelectric import amberelectric
from amberelectric.api import amber_api from amberelectric.models.site import Site
from amberelectric.model.site import Site, SiteStatus from amberelectric.models.site_status import SiteStatus
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
@ -23,11 +23,15 @@ API_URL = "https://app.amber.com.au/developers"
def generate_site_selector_name(site: Site) -> str: def generate_site_selector_name(site: Site) -> str:
"""Generate the name to show in the site drop down in the configuration flow.""" """Generate the name to show in the site drop down in the configuration flow."""
# For some reason the generated API key returns this as any, not a string. Thanks pydantic
nmi = str(site.nmi)
if site.status == SiteStatus.CLOSED: if site.status == SiteStatus.CLOSED:
return site.nmi + " (Closed: " + site.closed_on.isoformat() + ")" # type: ignore[no-any-return] if site.closed_on is None:
return f"{nmi} (Closed)"
return f"{nmi} (Closed: {site.closed_on.isoformat()})"
if site.status == SiteStatus.PENDING: if site.status == SiteStatus.PENDING:
return site.nmi + " (Pending)" # type: ignore[no-any-return] return f"{nmi} (Pending)"
return site.nmi # type: ignore[no-any-return] return nmi
def filter_sites(sites: list[Site]) -> list[Site]: def filter_sites(sites: list[Site]) -> list[Site]:
@ -35,7 +39,7 @@ def filter_sites(sites: list[Site]) -> list[Site]:
filtered: list[Site] = [] filtered: list[Site] = []
filtered_nmi: set[str] = set() filtered_nmi: set[str] = set()
for site in sorted(sites, key=lambda site: site.status.value): for site in sorted(sites, key=lambda site: site.status):
if site.status == SiteStatus.ACTIVE or site.nmi not in filtered_nmi: if site.status == SiteStatus.ACTIVE or site.nmi not in filtered_nmi:
filtered.append(site) filtered.append(site)
filtered_nmi.add(site.nmi) filtered_nmi.add(site.nmi)
@ -56,7 +60,8 @@ class AmberElectricConfigFlow(ConfigFlow, domain=DOMAIN):
def _fetch_sites(self, token: str) -> list[Site] | None: def _fetch_sites(self, token: str) -> list[Site] | None:
configuration = amberelectric.Configuration(access_token=token) configuration = amberelectric.Configuration(access_token=token)
api: amber_api.AmberApi = amber_api.AmberApi.create(configuration) api_client = amberelectric.ApiClient(configuration)
api = amberelectric.AmberApi(api_client)
try: try:
sites: list[Site] = filter_sites(api.get_sites()) sites: list[Site] = filter_sites(api.get_sites())

View File

@ -5,13 +5,13 @@ from __future__ import annotations
from datetime import timedelta from datetime import timedelta
from typing import Any from typing import Any
from amberelectric import ApiException import amberelectric
from amberelectric.api import amber_api from amberelectric.models.actual_interval import ActualInterval
from amberelectric.model.actual_interval import ActualInterval from amberelectric.models.channel import ChannelType
from amberelectric.model.channel import ChannelType from amberelectric.models.current_interval import CurrentInterval
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.forecast_interval import ForecastInterval
from amberelectric.model.forecast_interval import ForecastInterval from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.model.interval import Descriptor from amberelectric.rest import ApiException
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
@ -31,22 +31,22 @@ def is_forecast(interval: ActualInterval | CurrentInterval | ForecastInterval) -
def is_general(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool: def is_general(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool:
"""Return true if the supplied interval is on the general channel.""" """Return true if the supplied interval is on the general channel."""
return interval.channel_type == ChannelType.GENERAL # type: ignore[no-any-return] return interval.channel_type == ChannelType.GENERAL
def is_controlled_load( def is_controlled_load(
interval: ActualInterval | CurrentInterval | ForecastInterval, interval: ActualInterval | CurrentInterval | ForecastInterval,
) -> bool: ) -> bool:
"""Return true if the supplied interval is on the controlled load channel.""" """Return true if the supplied interval is on the controlled load channel."""
return interval.channel_type == ChannelType.CONTROLLED_LOAD # type: ignore[no-any-return] return interval.channel_type == ChannelType.CONTROLLEDLOAD
def is_feed_in(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool: def is_feed_in(interval: ActualInterval | CurrentInterval | ForecastInterval) -> bool:
"""Return true if the supplied interval is on the feed in channel.""" """Return true if the supplied interval is on the feed in channel."""
return interval.channel_type == ChannelType.FEED_IN # type: ignore[no-any-return] return interval.channel_type == ChannelType.FEEDIN
def normalize_descriptor(descriptor: Descriptor) -> str | None: def normalize_descriptor(descriptor: PriceDescriptor | None) -> str | None:
"""Return the snake case versions of descriptor names. Returns None if the name is not recognized.""" """Return the snake case versions of descriptor names. Returns None if the name is not recognized."""
if descriptor is None: if descriptor is None:
return None return None
@ -71,7 +71,7 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
"""AmberUpdateCoordinator - In charge of downloading the data for a site, which all the sensors read.""" """AmberUpdateCoordinator - In charge of downloading the data for a site, which all the sensors read."""
def __init__( def __init__(
self, hass: HomeAssistant, api: amber_api.AmberApi, site_id: str self, hass: HomeAssistant, api: amberelectric.AmberApi, site_id: str
) -> None: ) -> None:
"""Initialise the data service.""" """Initialise the data service."""
super().__init__( super().__init__(
@ -93,12 +93,13 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
"grid": {}, "grid": {},
} }
try: try:
data = self._api.get_current_price(self.site_id, next=48) data = self._api.get_current_prices(self.site_id, next=48)
intervals = [interval.actual_instance for interval in data]
except ApiException as api_exception: except ApiException as api_exception:
raise UpdateFailed("Missing price data, skipping update") from api_exception raise UpdateFailed("Missing price data, skipping update") from api_exception
current = [interval for interval in data if is_current(interval)] current = [interval for interval in intervals if is_current(interval)]
forecasts = [interval for interval in data if is_forecast(interval)] forecasts = [interval for interval in intervals if is_forecast(interval)]
general = [interval for interval in current if is_general(interval)] general = [interval for interval in current if is_general(interval)]
if len(general) == 0: if len(general) == 0:
@ -137,7 +138,7 @@ class AmberUpdateCoordinator(DataUpdateCoordinator):
interval for interval in forecasts if is_feed_in(interval) interval for interval in forecasts if is_feed_in(interval)
] ]
LOGGER.debug("Fetched new Amber data: %s", data) LOGGER.debug("Fetched new Amber data: %s", intervals)
return result return result
async def _async_update_data(self) -> dict[str, Any]: async def _async_update_data(self) -> dict[str, Any]:

View File

@ -6,5 +6,5 @@
"documentation": "https://www.home-assistant.io/integrations/amberelectric", "documentation": "https://www.home-assistant.io/integrations/amberelectric",
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"loggers": ["amberelectric"], "loggers": ["amberelectric"],
"requirements": ["amberelectric==1.1.1"] "requirements": ["amberelectric==2.0.12"]
} }

View File

@ -8,9 +8,9 @@ from __future__ import annotations
from typing import Any from typing import Any
from amberelectric.model.channel import ChannelType from amberelectric.models.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.current_interval import CurrentInterval
from amberelectric.model.forecast_interval import ForecastInterval from amberelectric.models.forecast_interval import ForecastInterval
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
SensorEntity, SensorEntity,
@ -52,7 +52,7 @@ class AmberSensor(CoordinatorEntity[AmberUpdateCoordinator], SensorEntity):
self, self,
coordinator: AmberUpdateCoordinator, coordinator: AmberUpdateCoordinator,
description: SensorEntityDescription, description: SensorEntityDescription,
channel_type: ChannelType, channel_type: str,
) -> None: ) -> None:
"""Initialize the Sensor.""" """Initialize the Sensor."""
super().__init__(coordinator) super().__init__(coordinator)
@ -73,7 +73,7 @@ class AmberPriceSensor(AmberSensor):
"""Return the current price in $/kWh.""" """Return the current price in $/kWh."""
interval = self.coordinator.data[self.entity_description.key][self.channel_type] interval = self.coordinator.data[self.entity_description.key][self.channel_type]
if interval.channel_type == ChannelType.FEED_IN: if interval.channel_type == ChannelType.FEEDIN:
return format_cents_to_dollars(interval.per_kwh) * -1 return format_cents_to_dollars(interval.per_kwh) * -1
return format_cents_to_dollars(interval.per_kwh) return format_cents_to_dollars(interval.per_kwh)
@ -87,9 +87,9 @@ class AmberPriceSensor(AmberSensor):
return data return data
data["duration"] = interval.duration data["duration"] = interval.duration
data["date"] = interval.date.isoformat() data["date"] = interval.var_date.isoformat()
data["per_kwh"] = format_cents_to_dollars(interval.per_kwh) data["per_kwh"] = format_cents_to_dollars(interval.per_kwh)
if interval.channel_type == ChannelType.FEED_IN: if interval.channel_type == ChannelType.FEEDIN:
data["per_kwh"] = data["per_kwh"] * -1 data["per_kwh"] = data["per_kwh"] * -1
data["nem_date"] = interval.nem_time.isoformat() data["nem_date"] = interval.nem_time.isoformat()
data["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh) data["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh)
@ -120,7 +120,7 @@ class AmberForecastSensor(AmberSensor):
return None return None
interval = intervals[0] interval = intervals[0]
if interval.channel_type == ChannelType.FEED_IN: if interval.channel_type == ChannelType.FEEDIN:
return format_cents_to_dollars(interval.per_kwh) * -1 return format_cents_to_dollars(interval.per_kwh) * -1
return format_cents_to_dollars(interval.per_kwh) return format_cents_to_dollars(interval.per_kwh)
@ -142,10 +142,10 @@ class AmberForecastSensor(AmberSensor):
for interval in intervals: for interval in intervals:
datum = {} datum = {}
datum["duration"] = interval.duration datum["duration"] = interval.duration
datum["date"] = interval.date.isoformat() datum["date"] = interval.var_date.isoformat()
datum["nem_date"] = interval.nem_time.isoformat() datum["nem_date"] = interval.nem_time.isoformat()
datum["per_kwh"] = format_cents_to_dollars(interval.per_kwh) datum["per_kwh"] = format_cents_to_dollars(interval.per_kwh)
if interval.channel_type == ChannelType.FEED_IN: if interval.channel_type == ChannelType.FEEDIN:
datum["per_kwh"] = datum["per_kwh"] * -1 datum["per_kwh"] = datum["per_kwh"] * -1
datum["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh) datum["spot_per_kwh"] = format_cents_to_dollars(interval.spot_per_kwh)
datum["start_time"] = interval.start_time.isoformat() datum["start_time"] = interval.start_time.isoformat()

View File

@ -447,7 +447,7 @@ airtouch5py==0.2.10
alpha-vantage==2.3.1 alpha-vantage==2.3.1
# homeassistant.components.amberelectric # homeassistant.components.amberelectric
amberelectric==1.1.1 amberelectric==2.0.12
# homeassistant.components.amcrest # homeassistant.components.amcrest
amcrest==1.9.8 amcrest==1.9.8

View File

@ -426,7 +426,7 @@ airtouch4pyapi==1.0.5
airtouch5py==0.2.10 airtouch5py==0.2.10
# homeassistant.components.amberelectric # homeassistant.components.amberelectric
amberelectric==1.1.1 amberelectric==2.0.12
# homeassistant.components.androidtv # homeassistant.components.androidtv
androidtv[async]==0.0.73 androidtv[async]==0.0.73

View File

@ -2,20 +2,22 @@
from datetime import datetime, timedelta from datetime import datetime, timedelta
from amberelectric.model.actual_interval import ActualInterval from amberelectric.models.actual_interval import ActualInterval
from amberelectric.model.channel import ChannelType from amberelectric.models.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.current_interval import CurrentInterval
from amberelectric.model.forecast_interval import ForecastInterval from amberelectric.models.forecast_interval import ForecastInterval
from amberelectric.model.interval import Descriptor, SpikeStatus from amberelectric.models.interval import Interval
from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.models.spike_status import SpikeStatus
from dateutil import parser from dateutil import parser
def generate_actual_interval( def generate_actual_interval(channel_type: ChannelType, end_time: datetime) -> Interval:
channel_type: ChannelType, end_time: datetime
) -> ActualInterval:
"""Generate a mock actual interval.""" """Generate a mock actual interval."""
start_time = end_time - timedelta(minutes=30) start_time = end_time - timedelta(minutes=30)
return ActualInterval( return Interval(
ActualInterval(
type="ActualInterval",
duration=30, duration=30,
spot_per_kwh=1.0, spot_per_kwh=1.0,
per_kwh=8.0, per_kwh=8.0,
@ -24,18 +26,21 @@ def generate_actual_interval(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
renewables=50, renewables=50,
channel_type=channel_type.value, channel_type=channel_type,
spike_status=SpikeStatus.NO_SPIKE.value, spike_status=SpikeStatus.NONE,
descriptor=Descriptor.LOW.value, descriptor=PriceDescriptor.LOW,
)
) )
def generate_current_interval( def generate_current_interval(
channel_type: ChannelType, end_time: datetime channel_type: ChannelType, end_time: datetime
) -> CurrentInterval: ) -> Interval:
"""Generate a mock current price.""" """Generate a mock current price."""
start_time = end_time - timedelta(minutes=30) start_time = end_time - timedelta(minutes=30)
return CurrentInterval( return Interval(
CurrentInterval(
type="CurrentInterval",
duration=30, duration=30,
spot_per_kwh=1.0, spot_per_kwh=1.0,
per_kwh=8.0, per_kwh=8.0,
@ -44,19 +49,22 @@ def generate_current_interval(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
renewables=50.6, renewables=50.6,
channel_type=channel_type.value, channel_type=channel_type,
spike_status=SpikeStatus.NO_SPIKE.value, spike_status=SpikeStatus.NONE,
descriptor=Descriptor.EXTREMELY_LOW.value, descriptor=PriceDescriptor.EXTREMELYLOW,
estimate=True, estimate=True,
) )
)
def generate_forecast_interval( def generate_forecast_interval(
channel_type: ChannelType, end_time: datetime channel_type: ChannelType, end_time: datetime
) -> ForecastInterval: ) -> Interval:
"""Generate a mock forecast interval.""" """Generate a mock forecast interval."""
start_time = end_time - timedelta(minutes=30) start_time = end_time - timedelta(minutes=30)
return ForecastInterval( return Interval(
ForecastInterval(
type="ForecastInterval",
duration=30, duration=30,
spot_per_kwh=1.1, spot_per_kwh=1.1,
per_kwh=8.8, per_kwh=8.8,
@ -65,11 +73,12 @@ def generate_forecast_interval(
start_time=start_time, start_time=start_time,
end_time=end_time, end_time=end_time,
renewables=50, renewables=50,
channel_type=channel_type.value, channel_type=channel_type,
spike_status=SpikeStatus.NO_SPIKE.value, spike_status=SpikeStatus.NONE,
descriptor=Descriptor.VERY_LOW.value, descriptor=PriceDescriptor.VERYLOW,
estimate=True, estimate=True,
) )
)
GENERAL_ONLY_SITE_ID = "01FG2K6V5TB6X9W0EWPPMZD6MJ" GENERAL_ONLY_SITE_ID = "01FG2K6V5TB6X9W0EWPPMZD6MJ"
@ -94,31 +103,31 @@ GENERAL_CHANNEL = [
CONTROLLED_LOAD_CHANNEL = [ CONTROLLED_LOAD_CHANNEL = [
generate_current_interval( generate_current_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T08:30:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T09:00:00+10:00") ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T09:00:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T09:30:00+10:00") ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T09:30:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.CONTROLLED_LOAD, parser.parse("2021-09-21T10:00:00+10:00") ChannelType.CONTROLLEDLOAD, parser.parse("2021-09-21T10:00:00+10:00")
), ),
] ]
FEED_IN_CHANNEL = [ FEED_IN_CHANNEL = [
generate_current_interval( generate_current_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.FEEDIN, parser.parse("2021-09-21T08:30:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T09:00:00+10:00") ChannelType.FEEDIN, parser.parse("2021-09-21T09:00:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T09:30:00+10:00") ChannelType.FEEDIN, parser.parse("2021-09-21T09:30:00+10:00")
), ),
generate_forecast_interval( generate_forecast_interval(
ChannelType.FEED_IN, parser.parse("2021-09-21T10:00:00+10:00") ChannelType.FEEDIN, parser.parse("2021-09-21T10:00:00+10:00")
), ),
] ]

View File

@ -5,10 +5,10 @@ from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from amberelectric.model.channel import ChannelType from amberelectric.models.channel import ChannelType
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.current_interval import CurrentInterval
from amberelectric.model.interval import SpikeStatus from amberelectric.models.spike_status import SpikeStatus
from amberelectric.model.tariff_information import TariffInformation from amberelectric.models.tariff_information import TariffInformation
from dateutil import parser from dateutil import parser
import pytest import pytest
@ -42,10 +42,10 @@ async def setup_no_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
instance.get_current_price = Mock(return_value=GENERAL_CHANNEL) instance.get_current_prices = Mock(return_value=GENERAL_CHANNEL)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value
@ -65,7 +65,7 @@ async def setup_potential_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
general_channel: list[CurrentInterval] = [ general_channel: list[CurrentInterval] = [
@ -73,8 +73,8 @@ async def setup_potential_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), ),
] ]
general_channel[0].spike_status = SpikeStatus.POTENTIAL general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL
instance.get_current_price = Mock(return_value=general_channel) instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value
@ -94,7 +94,7 @@ async def setup_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
general_channel: list[CurrentInterval] = [ general_channel: list[CurrentInterval] = [
@ -102,8 +102,8 @@ async def setup_spike(hass: HomeAssistant) -> AsyncGenerator[Mock]:
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), ),
] ]
general_channel[0].spike_status = SpikeStatus.SPIKE general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE
instance.get_current_price = Mock(return_value=general_channel) instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value
@ -156,7 +156,7 @@ async def setup_inactive_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mo
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
general_channel: list[CurrentInterval] = [ general_channel: list[CurrentInterval] = [
@ -164,8 +164,10 @@ async def setup_inactive_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mo
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), ),
] ]
general_channel[0].tariff_information = TariffInformation(demandWindow=False) general_channel[0].actual_instance.tariff_information = TariffInformation(
instance.get_current_price = Mock(return_value=general_channel) demandWindow=False
)
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value
@ -185,7 +187,7 @@ async def setup_active_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mock
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
general_channel: list[CurrentInterval] = [ general_channel: list[CurrentInterval] = [
@ -193,8 +195,10 @@ async def setup_active_demand_window(hass: HomeAssistant) -> AsyncGenerator[Mock
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), ),
] ]
general_channel[0].tariff_information = TariffInformation(demandWindow=True) general_channel[0].actual_instance.tariff_information = TariffInformation(
instance.get_current_price = Mock(return_value=general_channel) demandWindow=True
)
instance.get_current_prices = Mock(return_value=general_channel)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value

View File

@ -5,7 +5,8 @@ from datetime import date
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from amberelectric import ApiException from amberelectric import ApiException
from amberelectric.model.site import Site, SiteStatus from amberelectric.models.site import Site
from amberelectric.models.site_status import SiteStatus
import pytest import pytest
from homeassistant.components.amberelectric.config_flow import filter_sites from homeassistant.components.amberelectric.config_flow import filter_sites
@ -28,7 +29,7 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry")
def mock_invalid_key_api() -> Generator: def mock_invalid_key_api() -> Generator:
"""Return an authentication error.""" """Return an authentication error."""
with patch("amberelectric.api.AmberApi.create") as mock: with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=403) mock.return_value.get_sites.side_effect = ApiException(status=403)
yield mock yield mock
@ -36,7 +37,7 @@ def mock_invalid_key_api() -> Generator:
@pytest.fixture(name="api_error") @pytest.fixture(name="api_error")
def mock_api_error() -> Generator: def mock_api_error() -> Generator:
"""Return an authentication error.""" """Return an authentication error."""
with patch("amberelectric.api.AmberApi.create") as mock: with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.side_effect = ApiException(status=500) mock.return_value.get_sites.side_effect = ApiException(status=500)
yield mock yield mock
@ -45,16 +46,36 @@ def mock_api_error() -> Generator:
def mock_single_site_api() -> Generator: def mock_single_site_api() -> Generator:
"""Return a single site.""" """Return a single site."""
site = Site( site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH", id="01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111", nmi="11111111111",
[], channels=[],
"Jemena", network="Jemena",
SiteStatus.ACTIVE, status=SiteStatus.ACTIVE,
date(2002, 1, 1), active_from=date(2002, 1, 1),
None, closed_on=None,
interval_length=30,
) )
with patch("amberelectric.api.AmberApi.create") as mock: with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site]
yield mock
@pytest.fixture(name="single_site_closed_no_close_date_api")
def single_site_closed_no_close_date_api() -> Generator:
"""Return a single closed site with no closed date."""
site = Site(
id="01FG0AGP818PXK0DWHXJRRT2DH",
nmi="11111111111",
channels=[],
network="Jemena",
status=SiteStatus.CLOSED,
active_from=None,
closed_on=None,
interval_length=30,
)
with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site] mock.return_value.get_sites.return_value = [site]
yield mock yield mock
@ -63,16 +84,17 @@ def mock_single_site_api() -> Generator:
def mock_single_site_pending_api() -> Generator: def mock_single_site_pending_api() -> Generator:
"""Return a single site.""" """Return a single site."""
site = Site( site = Site(
"01FG0AGP818PXK0DWHXJRRT2DH", id="01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111", nmi="11111111111",
[], channels=[],
"Jemena", network="Jemena",
SiteStatus.PENDING, status=SiteStatus.PENDING,
None, active_from=None,
None, closed_on=None,
interval_length=30,
) )
with patch("amberelectric.api.AmberApi.create") as mock: with patch("amberelectric.AmberApi") as mock:
mock.return_value.get_sites.return_value = [site] mock.return_value.get_sites.return_value = [site]
yield mock yield mock
@ -82,35 +104,38 @@ def mock_single_site_rejoin_api() -> Generator:
"""Return a single site.""" """Return a single site."""
instance = Mock() instance = Mock()
site_1 = Site( site_1 = Site(
"01HGD9QB72HB3DWQNJ6SSCGXGV", id="01HGD9QB72HB3DWQNJ6SSCGXGV",
"11111111111", nmi="11111111111",
[], channels=[],
"Jemena", network="Jemena",
SiteStatus.CLOSED, status=SiteStatus.CLOSED,
date(2002, 1, 1), active_from=date(2002, 1, 1),
date(2002, 6, 1), closed_on=date(2002, 6, 1),
interval_length=30,
) )
site_2 = Site( site_2 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH", id="01FG0AGP818PXK0DWHXJRRT2DH",
"11111111111", nmi="11111111111",
[], channels=[],
"Jemena", network="Jemena",
SiteStatus.ACTIVE, status=SiteStatus.ACTIVE,
date(2003, 1, 1), active_from=date(2003, 1, 1),
None, closed_on=None,
interval_length=30,
) )
site_3 = Site( site_3 = Site(
"01FG0AGP818PXK0DWHXJRRT2DH", id="01FG0AGP818PXK0DWHXJRRT2DH",
"11111111112", nmi="11111111112",
[], channels=[],
"Jemena", network="Jemena",
SiteStatus.CLOSED, status=SiteStatus.CLOSED,
date(2003, 1, 1), active_from=date(2003, 1, 1),
date(2003, 6, 1), closed_on=date(2003, 6, 1),
interval_length=30,
) )
instance.get_sites.return_value = [site_1, site_2, site_3] instance.get_sites.return_value = [site_1, site_2, site_3]
with patch("amberelectric.api.AmberApi.create", return_value=instance): with patch("amberelectric.AmberApi", return_value=instance):
yield instance yield instance
@ -120,7 +145,7 @@ def mock_no_site_api() -> Generator:
instance = Mock() instance = Mock()
instance.get_sites.return_value = [] instance.get_sites.return_value = []
with patch("amberelectric.api.AmberApi.create", return_value=instance): with patch("amberelectric.AmberApi", return_value=instance):
yield instance yield instance
@ -188,6 +213,39 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None:
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH" assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_closed_site_no_closed_date(
hass: HomeAssistant, single_site_closed_no_close_date_api: Mock
) -> None:
"""Test single closed site with no closed date."""
initial_result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}
)
assert initial_result.get("type") is FlowResultType.FORM
assert initial_result.get("step_id") == "user"
# Test filling in API key
enter_api_key_result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
data={CONF_API_TOKEN: API_KEY},
)
assert enter_api_key_result.get("type") is FlowResultType.FORM
assert enter_api_key_result.get("step_id") == "site"
select_site_result = await hass.config_entries.flow.async_configure(
enter_api_key_result["flow_id"],
{CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"},
)
# Show available sites
assert select_site_result.get("type") is FlowResultType.CREATE_ENTRY
assert select_site_result.get("title") == "Home"
data = select_site_result.get("data")
assert data
assert data[CONF_API_TOKEN] == API_KEY
assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH"
async def test_single_site_rejoin( async def test_single_site_rejoin(
hass: HomeAssistant, single_site_rejoin_api: Mock hass: HomeAssistant, single_site_rejoin_api: Mock
) -> None: ) -> None:

View File

@ -7,10 +7,12 @@ from datetime import date
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from amberelectric import ApiException from amberelectric import ApiException
from amberelectric.model.channel import Channel, ChannelType from amberelectric.models.channel import Channel, ChannelType
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.interval import Interval
from amberelectric.model.interval import Descriptor, SpikeStatus from amberelectric.models.price_descriptor import PriceDescriptor
from amberelectric.model.site import Site, SiteStatus from amberelectric.models.site import Site
from amberelectric.models.site_status import SiteStatus
from amberelectric.models.spike_status import SpikeStatus
from dateutil import parser from dateutil import parser
import pytest import pytest
@ -38,37 +40,40 @@ def mock_api_current_price() -> Generator:
instance = Mock() instance = Mock()
general_site = Site( general_site = Site(
GENERAL_ONLY_SITE_ID, id=GENERAL_ONLY_SITE_ID,
"11111111111", nmi="11111111111",
[Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")], channels=[Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")],
"Jemena", network="Jemena",
SiteStatus.ACTIVE, status=SiteStatus("active"),
date(2021, 1, 1), activeFrom=date(2021, 1, 1),
None, closedOn=None,
interval_length=30,
) )
general_and_controlled_load = Site( general_and_controlled_load = Site(
GENERAL_AND_CONTROLLED_SITE_ID, id=GENERAL_AND_CONTROLLED_SITE_ID,
"11111111112", nmi="11111111112",
[ channels=[
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"), Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD, tariff="A180"), Channel(identifier="E2", type=ChannelType.CONTROLLEDLOAD, tariff="A180"),
], ],
"Jemena", network="Jemena",
SiteStatus.ACTIVE, status=SiteStatus("active"),
date(2021, 1, 1), activeFrom=date(2021, 1, 1),
None, closedOn=None,
interval_length=30,
) )
general_and_feed_in = Site( general_and_feed_in = Site(
GENERAL_AND_FEED_IN_SITE_ID, id=GENERAL_AND_FEED_IN_SITE_ID,
"11111111113", nmi="11111111113",
[ channels=[
Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"), Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"),
Channel(identifier="E2", type=ChannelType.FEED_IN, tariff="A100"), Channel(identifier="E2", type=ChannelType.FEEDIN, tariff="A100"),
], ],
"Jemena", network="Jemena",
SiteStatus.ACTIVE, status=SiteStatus("active"),
date(2021, 1, 1), activeFrom=date(2021, 1, 1),
None, closedOn=None,
interval_length=30,
) )
instance.get_sites.return_value = [ instance.get_sites.return_value = [
general_site, general_site,
@ -76,44 +81,46 @@ def mock_api_current_price() -> Generator:
general_and_feed_in, general_and_feed_in,
] ]
with patch("amberelectric.api.AmberApi.create", return_value=instance): with patch("amberelectric.AmberApi", return_value=instance):
yield instance yield instance
def test_normalize_descriptor() -> None: def test_normalize_descriptor() -> None:
"""Test normalizing descriptors works correctly.""" """Test normalizing descriptors works correctly."""
assert normalize_descriptor(None) is None assert normalize_descriptor(None) is None
assert normalize_descriptor(Descriptor.NEGATIVE) == "negative" assert normalize_descriptor(PriceDescriptor.NEGATIVE) == "negative"
assert normalize_descriptor(Descriptor.EXTREMELY_LOW) == "extremely_low" assert normalize_descriptor(PriceDescriptor.EXTREMELYLOW) == "extremely_low"
assert normalize_descriptor(Descriptor.VERY_LOW) == "very_low" assert normalize_descriptor(PriceDescriptor.VERYLOW) == "very_low"
assert normalize_descriptor(Descriptor.LOW) == "low" assert normalize_descriptor(PriceDescriptor.LOW) == "low"
assert normalize_descriptor(Descriptor.NEUTRAL) == "neutral" assert normalize_descriptor(PriceDescriptor.NEUTRAL) == "neutral"
assert normalize_descriptor(Descriptor.HIGH) == "high" assert normalize_descriptor(PriceDescriptor.HIGH) == "high"
assert normalize_descriptor(Descriptor.SPIKE) == "spike" assert normalize_descriptor(PriceDescriptor.SPIKE) == "spike"
async def test_fetch_general_site(hass: HomeAssistant, current_price_api: Mock) -> None: async def test_fetch_general_site(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test fetching a site with only a general channel.""" """Test fetching a site with only a general channel."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data() result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with( current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48 GENERAL_ONLY_SITE_ID, next=48
) )
assert result["current"].get("general") == GENERAL_CHANNEL[0] assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [ assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1], GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2], GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3], GENERAL_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("controlled_load") is None assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables) assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none" assert result["grid"]["price_spike"] == "none"
@ -122,12 +129,12 @@ async def test_fetch_no_general_site(
) -> None: ) -> None:
"""Test fetching a site with no general channel.""" """Test fetching a site with no general channel."""
current_price_api.get_current_price.return_value = CONTROLLED_LOAD_CHANNEL current_price_api.get_current_prices.return_value = CONTROLLED_LOAD_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
with pytest.raises(UpdateFailed): with pytest.raises(UpdateFailed):
await data_service._async_update_data() await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with( current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48 GENERAL_ONLY_SITE_ID, next=48
) )
@ -135,41 +142,45 @@ async def test_fetch_no_general_site(
async def test_fetch_api_error(hass: HomeAssistant, current_price_api: Mock) -> None: async def test_fetch_api_error(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test that the old values are maintained if a second call fails.""" """Test that the old values are maintained if a second call fails."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL current_price_api.get_current_prices.return_value = GENERAL_CHANNEL
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data() result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with( current_price_api.get_current_prices.assert_called_with(
GENERAL_ONLY_SITE_ID, next=48 GENERAL_ONLY_SITE_ID, next=48
) )
assert result["current"].get("general") == GENERAL_CHANNEL[0] assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [ assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1], GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2], GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3], GENERAL_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("controlled_load") is None assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables) assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
current_price_api.get_current_price.side_effect = ApiException(status=403) current_price_api.get_current_prices.side_effect = ApiException(status=403)
with pytest.raises(UpdateFailed): with pytest.raises(UpdateFailed):
await data_service._async_update_data() await data_service._async_update_data()
assert result["current"].get("general") == GENERAL_CHANNEL[0] assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [ assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1], GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2], GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3], GENERAL_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("controlled_load") is None assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is None assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables) assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none" assert result["grid"]["price_spike"] == "none"
@ -178,7 +189,7 @@ async def test_fetch_general_and_controlled_load_site(
) -> None: ) -> None:
"""Test fetching a site with a general and controlled load channel.""" """Test fetching a site with a general and controlled load channel."""
current_price_api.get_current_price.return_value = ( current_price_api.get_current_prices.return_value = (
GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL
) )
data_service = AmberUpdateCoordinator( data_service = AmberUpdateCoordinator(
@ -186,25 +197,30 @@ async def test_fetch_general_and_controlled_load_site(
) )
result = await data_service._async_update_data() result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with( current_price_api.get_current_prices.assert_called_with(
GENERAL_AND_CONTROLLED_SITE_ID, next=48 GENERAL_AND_CONTROLLED_SITE_ID, next=48
) )
assert result["current"].get("general") == GENERAL_CHANNEL[0] assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [ assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1], GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2], GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3], GENERAL_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("controlled_load") is CONTROLLED_LOAD_CHANNEL[0] assert (
result["current"].get("controlled_load")
is CONTROLLED_LOAD_CHANNEL[0].actual_instance
)
assert result["forecasts"].get("controlled_load") == [ assert result["forecasts"].get("controlled_load") == [
CONTROLLED_LOAD_CHANNEL[1], CONTROLLED_LOAD_CHANNEL[1].actual_instance,
CONTROLLED_LOAD_CHANNEL[2], CONTROLLED_LOAD_CHANNEL[2].actual_instance,
CONTROLLED_LOAD_CHANNEL[3], CONTROLLED_LOAD_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("feed_in") is None assert result["current"].get("feed_in") is None
assert result["forecasts"].get("feed_in") is None assert result["forecasts"].get("feed_in") is None
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables) assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none" assert result["grid"]["price_spike"] == "none"
@ -213,31 +229,35 @@ async def test_fetch_general_and_feed_in_site(
) -> None: ) -> None:
"""Test fetching a site with a general and feed_in channel.""" """Test fetching a site with a general and feed_in channel."""
current_price_api.get_current_price.return_value = GENERAL_CHANNEL + FEED_IN_CHANNEL current_price_api.get_current_prices.return_value = (
GENERAL_CHANNEL + FEED_IN_CHANNEL
)
data_service = AmberUpdateCoordinator( data_service = AmberUpdateCoordinator(
hass, current_price_api, GENERAL_AND_FEED_IN_SITE_ID hass, current_price_api, GENERAL_AND_FEED_IN_SITE_ID
) )
result = await data_service._async_update_data() result = await data_service._async_update_data()
current_price_api.get_current_price.assert_called_with( current_price_api.get_current_prices.assert_called_with(
GENERAL_AND_FEED_IN_SITE_ID, next=48 GENERAL_AND_FEED_IN_SITE_ID, next=48
) )
assert result["current"].get("general") == GENERAL_CHANNEL[0] assert result["current"].get("general") == GENERAL_CHANNEL[0].actual_instance
assert result["forecasts"].get("general") == [ assert result["forecasts"].get("general") == [
GENERAL_CHANNEL[1], GENERAL_CHANNEL[1].actual_instance,
GENERAL_CHANNEL[2], GENERAL_CHANNEL[2].actual_instance,
GENERAL_CHANNEL[3], GENERAL_CHANNEL[3].actual_instance,
] ]
assert result["current"].get("controlled_load") is None assert result["current"].get("controlled_load") is None
assert result["forecasts"].get("controlled_load") is None assert result["forecasts"].get("controlled_load") is None
assert result["current"].get("feed_in") is FEED_IN_CHANNEL[0] assert result["current"].get("feed_in") is FEED_IN_CHANNEL[0].actual_instance
assert result["forecasts"].get("feed_in") == [ assert result["forecasts"].get("feed_in") == [
FEED_IN_CHANNEL[1], FEED_IN_CHANNEL[1].actual_instance,
FEED_IN_CHANNEL[2], FEED_IN_CHANNEL[2].actual_instance,
FEED_IN_CHANNEL[3], FEED_IN_CHANNEL[3].actual_instance,
] ]
assert result["grid"]["renewables"] == round(GENERAL_CHANNEL[0].renewables) assert result["grid"]["renewables"] == round(
GENERAL_CHANNEL[0].actual_instance.renewables
)
assert result["grid"]["price_spike"] == "none" assert result["grid"]["price_spike"] == "none"
@ -246,13 +266,13 @@ async def test_fetch_potential_spike(
) -> None: ) -> None:
"""Test fetching a site with only a general channel.""" """Test fetching a site with only a general channel."""
general_channel: list[CurrentInterval] = [ general_channel: list[Interval] = [
generate_current_interval( generate_current_interval(
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), )
] ]
general_channel[0].spike_status = SpikeStatus.POTENTIAL general_channel[0].actual_instance.spike_status = SpikeStatus.POTENTIAL
current_price_api.get_current_price.return_value = general_channel current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data() result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "potential" assert result["grid"]["price_spike"] == "potential"
@ -261,13 +281,13 @@ async def test_fetch_potential_spike(
async def test_fetch_spike(hass: HomeAssistant, current_price_api: Mock) -> None: async def test_fetch_spike(hass: HomeAssistant, current_price_api: Mock) -> None:
"""Test fetching a site with only a general channel.""" """Test fetching a site with only a general channel."""
general_channel: list[CurrentInterval] = [ general_channel: list[Interval] = [
generate_current_interval( generate_current_interval(
ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00") ChannelType.GENERAL, parser.parse("2021-09-21T08:30:00+10:00")
), )
] ]
general_channel[0].spike_status = SpikeStatus.SPIKE general_channel[0].actual_instance.spike_status = SpikeStatus.SPIKE
current_price_api.get_current_price.return_value = general_channel current_price_api.get_current_prices.return_value = general_channel
data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID) data_service = AmberUpdateCoordinator(hass, current_price_api, GENERAL_ONLY_SITE_ID)
result = await data_service._async_update_data() result = await data_service._async_update_data()
assert result["grid"]["price_spike"] == "spike" assert result["grid"]["price_spike"] == "spike"

View File

@ -3,8 +3,9 @@
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from unittest.mock import Mock, patch from unittest.mock import Mock, patch
from amberelectric.model.current_interval import CurrentInterval from amberelectric.models.current_interval import CurrentInterval
from amberelectric.model.range import Range from amberelectric.models.interval import Interval
from amberelectric.models.range import Range
import pytest import pytest
from homeassistant.components.amberelectric.const import ( from homeassistant.components.amberelectric.const import (
@ -44,10 +45,10 @@ async def setup_general(hass: HomeAssistant) -> AsyncGenerator[Mock]:
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
instance.get_current_price = Mock(return_value=GENERAL_CHANNEL) instance.get_current_prices = Mock(return_value=GENERAL_CHANNEL)
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
await hass.async_block_till_done() await hass.async_block_till_done()
yield mock_update.return_value yield mock_update.return_value
@ -68,10 +69,10 @@ async def setup_general_and_controlled_load(
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
instance.get_current_price = Mock( instance.get_current_prices = Mock(
return_value=GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL return_value=GENERAL_CHANNEL + CONTROLLED_LOAD_CHANNEL
) )
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
@ -92,10 +93,10 @@ async def setup_general_and_feed_in(hass: HomeAssistant) -> AsyncGenerator[Mock]
instance = Mock() instance = Mock()
with patch( with patch(
"amberelectric.api.AmberApi.create", "amberelectric.AmberApi",
return_value=instance, return_value=instance,
) as mock_update: ) as mock_update:
instance.get_current_price = Mock( instance.get_current_prices = Mock(
return_value=GENERAL_CHANNEL + FEED_IN_CHANNEL return_value=GENERAL_CHANNEL + FEED_IN_CHANNEL
) )
assert await async_setup_component(hass, DOMAIN, {}) assert await async_setup_component(hass, DOMAIN, {})
@ -126,7 +127,7 @@ async def test_general_price_sensor(hass: HomeAssistant, setup_general: Mock) ->
assert attributes.get("range_max") is None assert attributes.get("range_max") is None
with_range: list[CurrentInterval] = GENERAL_CHANNEL with_range: list[CurrentInterval] = GENERAL_CHANNEL
with_range[0].range = Range(7.8, 12.4) with_range[0].actual_instance.range = Range(min=7.8, max=12.4)
setup_general.get_current_price.return_value = with_range setup_general.get_current_price.return_value = with_range
config_entry = hass.config_entries.async_entries(DOMAIN)[0] config_entry = hass.config_entries.async_entries(DOMAIN)[0]
@ -211,8 +212,8 @@ async def test_general_forecast_sensor(
assert first_forecast.get("range_min") is None assert first_forecast.get("range_min") is None
assert first_forecast.get("range_max") is None assert first_forecast.get("range_max") is None
with_range: list[CurrentInterval] = GENERAL_CHANNEL with_range: list[Interval] = GENERAL_CHANNEL
with_range[1].range = Range(7.8, 12.4) with_range[1].actual_instance.range = Range(min=7.8, max=12.4)
setup_general.get_current_price.return_value = with_range setup_general.get_current_price.return_value = with_range
config_entry = hass.config_entries.async_entries(DOMAIN)[0] config_entry = hass.config_entries.async_entries(DOMAIN)[0]