diff --git a/homeassistant/components/amberelectric/config_flow.py b/homeassistant/components/amberelectric/config_flow.py index 4011f442ee2..765e219b6d7 100644 --- a/homeassistant/components/amberelectric/config_flow.py +++ b/homeassistant/components/amberelectric/config_flow.py @@ -3,18 +3,46 @@ from __future__ import annotations import amberelectric from amberelectric.api import amber_api -from amberelectric.model.site import Site +from amberelectric.model.site import Site, SiteStatus import voluptuous as vol from homeassistant import config_entries from homeassistant.const import CONF_API_TOKEN from homeassistant.data_entry_flow import FlowResult +from homeassistant.helpers.selector import ( + SelectOptionDict, + SelectSelector, + SelectSelectorConfig, + SelectSelectorMode, +) -from .const import CONF_SITE_ID, CONF_SITE_NAME, CONF_SITE_NMI, DOMAIN +from .const import CONF_SITE_ID, CONF_SITE_NAME, DOMAIN API_URL = "https://app.amber.com.au/developers" +def generate_site_selector_name(site: Site) -> str: + """Generate the name to show in the site drop down in the configuration flow.""" + if site.status == SiteStatus.CLOSED: + return site.nmi + " (Closed: " + site.closed_on.isoformat() + ")" # type: ignore[no-any-return] + if site.status == SiteStatus.PENDING: + return site.nmi + " (Pending)" # type: ignore[no-any-return] + return site.nmi # type: ignore[no-any-return] + + +def filter_sites(sites: list[Site]) -> list[Site]: + """Deduplicates the list of sites.""" + filtered: list[Site] = [] + filtered_nmi: set[str] = set() + + for site in sorted(sites, key=lambda site: site.status.value): + if site.status == SiteStatus.ACTIVE or site.nmi not in filtered_nmi: + filtered.append(site) + filtered_nmi.add(site.nmi) + + return filtered + + class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): """Handle a config flow.""" @@ -31,7 +59,7 @@ class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): api: amber_api.AmberApi = amber_api.AmberApi.create(configuration) try: - sites: list[Site] = api.get_sites() + sites: list[Site] = filter_sites(api.get_sites()) if len(sites) == 0: self._errors[CONF_API_TOKEN] = "no_site" return None @@ -86,38 +114,31 @@ class AmberElectricConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): assert self._sites is not None assert self._api_token is not None - api_token = self._api_token if user_input is not None: - site_nmi = user_input[CONF_SITE_NMI] - sites = [site for site in self._sites if site.nmi == site_nmi] - site = sites[0] - site_id = site.id + site_id = user_input[CONF_SITE_ID] name = user_input.get(CONF_SITE_NAME, site_id) return self.async_create_entry( title=name, - data={ - CONF_SITE_ID: site_id, - CONF_API_TOKEN: api_token, - CONF_SITE_NMI: site.nmi, - }, + data={CONF_SITE_ID: site_id, CONF_API_TOKEN: self._api_token}, ) - user_input = { - CONF_API_TOKEN: api_token, - CONF_SITE_NMI: "", - CONF_SITE_NAME: "", - } - return self.async_show_form( step_id="site", data_schema=vol.Schema( { - vol.Required( - CONF_SITE_NMI, default=user_input[CONF_SITE_NMI] - ): vol.In([site.nmi for site in self._sites]), - vol.Optional( - CONF_SITE_NAME, default=user_input[CONF_SITE_NAME] - ): str, + vol.Required(CONF_SITE_ID): SelectSelector( + SelectSelectorConfig( + options=[ + SelectOptionDict( + value=site.id, + label=generate_site_selector_name(site), + ) + for site in self._sites + ], + mode=SelectSelectorMode.DROPDOWN, + ) + ), + vol.Optional(CONF_SITE_NAME): str, } ), errors=self._errors, diff --git a/homeassistant/components/amberelectric/const.py b/homeassistant/components/amberelectric/const.py index 8416b7ca33c..6166b21c19f 100644 --- a/homeassistant/components/amberelectric/const.py +++ b/homeassistant/components/amberelectric/const.py @@ -6,7 +6,6 @@ from homeassistant.const import Platform DOMAIN = "amberelectric" CONF_SITE_NAME = "site_name" CONF_SITE_ID = "site_id" -CONF_SITE_NMI = "site_nmi" ATTRIBUTION = "Data provided by Amber Electric" diff --git a/homeassistant/components/amberelectric/manifest.json b/homeassistant/components/amberelectric/manifest.json index 29de18d96de..13a9f257adb 100644 --- a/homeassistant/components/amberelectric/manifest.json +++ b/homeassistant/components/amberelectric/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://www.home-assistant.io/integrations/amberelectric", "iot_class": "cloud_polling", "loggers": ["amberelectric"], - "requirements": ["amberelectric==1.0.4"] + "requirements": ["amberelectric==1.1.0"] } diff --git a/requirements_all.txt b/requirements_all.txt index d70de10c4b5..4927a6d048a 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -425,7 +425,7 @@ airtouch5py==0.2.8 alpha-vantage==2.3.1 # homeassistant.components.amberelectric -amberelectric==1.0.4 +amberelectric==1.1.0 # homeassistant.components.amcrest amcrest==1.9.8 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index e18168ee8cb..d233f035c60 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -395,7 +395,7 @@ airtouch4pyapi==1.0.5 airtouch5py==0.2.8 # homeassistant.components.amberelectric -amberelectric==1.0.4 +amberelectric==1.1.0 # homeassistant.components.androidtv androidtv[async]==0.0.73 diff --git a/tests/components/amberelectric/test_config_flow.py b/tests/components/amberelectric/test_config_flow.py index 6325282aff8..2624bd96d31 100644 --- a/tests/components/amberelectric/test_config_flow.py +++ b/tests/components/amberelectric/test_config_flow.py @@ -1,17 +1,18 @@ """Tests for the Amber config flow.""" from collections.abc import Generator +from datetime import date from unittest.mock import Mock, patch from amberelectric import ApiException -from amberelectric.model.site import Site +from amberelectric.model.site import Site, SiteStatus import pytest from homeassistant import data_entry_flow +from homeassistant.components.amberelectric.config_flow import filter_sites from homeassistant.components.amberelectric.const import ( CONF_SITE_ID, CONF_SITE_NAME, - CONF_SITE_NMI, DOMAIN, ) from homeassistant.config_entries import SOURCE_USER @@ -26,29 +27,88 @@ pytestmark = pytest.mark.usefixtures("mock_setup_entry") @pytest.fixture(name="invalid_key_api") def mock_invalid_key_api() -> Generator: """Return an authentication error.""" - instance = Mock() - instance.get_sites.side_effect = ApiException(status=403) - with patch("amberelectric.api.AmberApi.create", return_value=instance): - yield instance + with patch("amberelectric.api.AmberApi.create") as mock: + mock.return_value.get_sites.side_effect = ApiException(status=403) + yield mock @pytest.fixture(name="api_error") def mock_api_error() -> Generator: """Return an authentication error.""" - instance = Mock() - instance.get_sites.side_effect = ApiException(status=500) - - with patch("amberelectric.api.AmberApi.create", return_value=instance): - yield instance + with patch("amberelectric.api.AmberApi.create") as mock: + mock.return_value.get_sites.side_effect = ApiException(status=500) + yield mock @pytest.fixture(name="single_site_api") def mock_single_site_api() -> Generator: + """Return a single site.""" + site = Site( + "01FG0AGP818PXK0DWHXJRRT2DH", + "11111111111", + [], + "Jemena", + SiteStatus.ACTIVE, + date(2002, 1, 1), + None, + ) + + with patch("amberelectric.api.AmberApi.create") as mock: + mock.return_value.get_sites.return_value = [site] + yield mock + + +@pytest.fixture(name="single_site_pending_api") +def mock_single_site_pending_api() -> Generator: + """Return a single site.""" + site = Site( + "01FG0AGP818PXK0DWHXJRRT2DH", + "11111111111", + [], + "Jemena", + SiteStatus.PENDING, + None, + None, + ) + + with patch("amberelectric.api.AmberApi.create") as mock: + mock.return_value.get_sites.return_value = [site] + yield mock + + +@pytest.fixture(name="single_site_rejoin_api") +def mock_single_site_rejoin_api() -> Generator: """Return a single site.""" instance = Mock() - site = Site("01FG0AGP818PXK0DWHXJRRT2DH", "11111111111", []) - instance.get_sites.return_value = [site] + site_1 = Site( + "01HGD9QB72HB3DWQNJ6SSCGXGV", + "11111111111", + [], + "Jemena", + SiteStatus.CLOSED, + date(2002, 1, 1), + date(2002, 6, 1), + ) + site_2 = Site( + "01FG0AGP818PXK0DWHXJRRT2DH", + "11111111111", + [], + "Jemena", + SiteStatus.ACTIVE, + date(2003, 1, 1), + None, + ) + site_3 = Site( + "01FG0AGP818PXK0DWHXJRRT2DH", + "11111111112", + [], + "Jemena", + SiteStatus.CLOSED, + date(2003, 1, 1), + date(2003, 6, 1), + ) + instance.get_sites.return_value = [site_1, site_2, site_3] with patch("amberelectric.api.AmberApi.create", return_value=instance): yield instance @@ -64,6 +124,39 @@ def mock_no_site_api() -> Generator: yield instance +async def test_single_pending_site( + hass: HomeAssistant, single_site_pending_api: Mock +) -> None: + """Test single site.""" + initial_result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert initial_result.get("type") == data_entry_flow.FlowResultType.FORM + assert initial_result.get("step_id") == "user" + + # Test filling in API key + enter_api_key_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data={CONF_API_TOKEN: API_KEY}, + ) + assert enter_api_key_result.get("type") == data_entry_flow.FlowResultType.FORM + assert enter_api_key_result.get("step_id") == "site" + + select_site_result = await hass.config_entries.flow.async_configure( + enter_api_key_result["flow_id"], + {CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"}, + ) + + # Show available sites + assert select_site_result.get("type") == data_entry_flow.FlowResultType.CREATE_ENTRY + assert select_site_result.get("title") == "Home" + data = select_site_result.get("data") + assert data + assert data[CONF_API_TOKEN] == API_KEY + assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH" + + async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None: """Test single site.""" initial_result = await hass.config_entries.flow.async_init( @@ -83,7 +176,40 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None: select_site_result = await hass.config_entries.flow.async_configure( enter_api_key_result["flow_id"], - {CONF_SITE_NMI: "11111111111", CONF_SITE_NAME: "Home"}, + {CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"}, + ) + + # Show available sites + assert select_site_result.get("type") == data_entry_flow.FlowResultType.CREATE_ENTRY + assert select_site_result.get("title") == "Home" + data = select_site_result.get("data") + assert data + assert data[CONF_API_TOKEN] == API_KEY + assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH" + + +async def test_single_site_rejoin( + hass: HomeAssistant, single_site_rejoin_api: Mock +) -> None: + """Test single site.""" + initial_result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": SOURCE_USER} + ) + assert initial_result.get("type") == data_entry_flow.FlowResultType.FORM + assert initial_result.get("step_id") == "user" + + # Test filling in API key + enter_api_key_result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + data={CONF_API_TOKEN: API_KEY}, + ) + assert enter_api_key_result.get("type") == data_entry_flow.FlowResultType.FORM + assert enter_api_key_result.get("step_id") == "site" + + select_site_result = await hass.config_entries.flow.async_configure( + enter_api_key_result["flow_id"], + {CONF_SITE_ID: "01FG0AGP818PXK0DWHXJRRT2DH", CONF_SITE_NAME: "Home"}, ) # Show available sites @@ -93,7 +219,6 @@ async def test_single_site(hass: HomeAssistant, single_site_api: Mock) -> None: assert data assert data[CONF_API_TOKEN] == API_KEY assert data[CONF_SITE_ID] == "01FG0AGP818PXK0DWHXJRRT2DH" - assert data[CONF_SITE_NMI] == "11111111111" async def test_no_site(hass: HomeAssistant, no_site_api: Mock) -> None: @@ -148,3 +273,15 @@ async def test_unknown_error(hass: HomeAssistant, api_error: Mock) -> None: # Goes back to the user step assert result.get("step_id") == "user" assert result.get("errors") == {"api_token": "unknown_error"} + + +async def test_site_deduplication(single_site_rejoin_api: Mock) -> None: + """Test site deduplication.""" + filtered = filter_sites(single_site_rejoin_api.get_sites()) + assert len(filtered) == 2 + assert ( + next(s for s in filtered if s.nmi == "11111111111").status == SiteStatus.ACTIVE + ) + assert ( + next(s for s in filtered if s.nmi == "11111111112").status == SiteStatus.CLOSED + ) diff --git a/tests/components/amberelectric/test_coordinator.py b/tests/components/amberelectric/test_coordinator.py index 64fa39192a6..7808d1adcde 100644 --- a/tests/components/amberelectric/test_coordinator.py +++ b/tests/components/amberelectric/test_coordinator.py @@ -2,13 +2,14 @@ from __future__ import annotations from collections.abc import Generator +from datetime import date from unittest.mock import Mock, patch from amberelectric import ApiException from amberelectric.model.channel import Channel, ChannelType from amberelectric.model.current_interval import CurrentInterval from amberelectric.model.interval import Descriptor, SpikeStatus -from amberelectric.model.site import Site +from amberelectric.model.site import Site, SiteStatus from dateutil import parser import pytest @@ -38,23 +39,35 @@ def mock_api_current_price() -> Generator: general_site = Site( GENERAL_ONLY_SITE_ID, "11111111111", - [Channel(identifier="E1", type=ChannelType.GENERAL)], + [Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100")], + "Jemena", + SiteStatus.ACTIVE, + date(2021, 1, 1), + None, ) general_and_controlled_load = Site( GENERAL_AND_CONTROLLED_SITE_ID, "11111111112", [ - Channel(identifier="E1", type=ChannelType.GENERAL), - Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD), + Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"), + Channel(identifier="E2", type=ChannelType.CONTROLLED_LOAD, tariff="A180"), ], + "Jemena", + SiteStatus.ACTIVE, + date(2021, 1, 1), + None, ) general_and_feed_in = Site( GENERAL_AND_FEED_IN_SITE_ID, "11111111113", [ - Channel(identifier="E1", type=ChannelType.GENERAL), - Channel(identifier="E2", type=ChannelType.FEED_IN), + Channel(identifier="E1", type=ChannelType.GENERAL, tariff="A100"), + Channel(identifier="E2", type=ChannelType.FEED_IN, tariff="A100"), ], + "Jemena", + SiteStatus.ACTIVE, + date(2021, 1, 1), + None, ) instance.get_sites.return_value = [ general_site,