From 16485af7fc6ce9fa02478fd30a2526661e718201 Mon Sep 17 00:00:00 2001 From: Joost Lekkerkerker Date: Mon, 3 Jun 2024 19:23:07 +0200 Subject: [PATCH] Configure device in airgradient config flow (#118699) --- .../components/airgradient/config_flow.py | 20 +++++-- .../components/airgradient/strings.json | 3 +- .../airgradient/test_config_flow.py | 56 ++++++++++++++++++- 3 files changed, 69 insertions(+), 10 deletions(-) diff --git a/homeassistant/components/airgradient/config_flow.py b/homeassistant/components/airgradient/config_flow.py index c02ec2a469f..c7b617de272 100644 --- a/homeassistant/components/airgradient/config_flow.py +++ b/homeassistant/components/airgradient/config_flow.py @@ -2,7 +2,7 @@ from typing import Any -from airgradient import AirGradientClient, AirGradientError +from airgradient import AirGradientClient, AirGradientError, ConfigurationControl import voluptuous as vol from homeassistant.components import zeroconf @@ -19,6 +19,14 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN): def __init__(self) -> None: """Initialize the config flow.""" self.data: dict[str, Any] = {} + self.client: AirGradientClient | None = None + + async def set_configuration_source(self) -> None: + """Set configuration source to local if it hasn't been set yet.""" + assert self.client + config = await self.client.get_config() + if config.configuration_control is ConfigurationControl.BOTH: + await self.client.set_configuration_control(ConfigurationControl.LOCAL) async def async_step_zeroconf( self, discovery_info: zeroconf.ZeroconfServiceInfo @@ -31,8 +39,8 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN): self._abort_if_unique_id_configured(updates={CONF_HOST: host}) session = async_get_clientsession(self.hass) - air_gradient = AirGradientClient(host, session=session) - await air_gradient.get_current_measures() + self.client = AirGradientClient(host, session=session) + await self.client.get_current_measures() self.context["title_placeholders"] = { "model": self.data[CONF_MODEL], @@ -44,6 +52,7 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN): ) -> ConfigFlowResult: """Confirm discovery.""" if user_input is not None: + await self.set_configuration_source() return self.async_create_entry( title=self.data[CONF_MODEL], data={CONF_HOST: self.data[CONF_HOST]}, @@ -64,14 +73,15 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN): errors: dict[str, str] = {} if user_input: session = async_get_clientsession(self.hass) - air_gradient = AirGradientClient(user_input[CONF_HOST], session=session) + self.client = AirGradientClient(user_input[CONF_HOST], session=session) try: - current_measures = await air_gradient.get_current_measures() + current_measures = await self.client.get_current_measures() except AirGradientError: errors["base"] = "cannot_connect" else: await self.async_set_unique_id(current_measures.serial_number) self._abort_if_unique_id_configured() + await self.set_configuration_source() return self.async_create_entry( title=current_measures.model, data={CONF_HOST: user_input[CONF_HOST]}, diff --git a/homeassistant/components/airgradient/strings.json b/homeassistant/components/airgradient/strings.json index f4441a66209..9deaf17d0e4 100644 --- a/homeassistant/components/airgradient/strings.json +++ b/homeassistant/components/airgradient/strings.json @@ -28,8 +28,7 @@ "name": "Configuration source", "state": { "cloud": "Cloud", - "local": "Local", - "both": "Both" + "local": "Local" } }, "display_temperature_unit": { diff --git a/tests/components/airgradient/test_config_flow.py b/tests/components/airgradient/test_config_flow.py index 022a250ebef..6bb951f2e26 100644 --- a/tests/components/airgradient/test_config_flow.py +++ b/tests/components/airgradient/test_config_flow.py @@ -3,7 +3,7 @@ from ipaddress import ip_address from unittest.mock import AsyncMock -from airgradient import AirGradientConnectionError +from airgradient import AirGradientConnectionError, ConfigurationControl from homeassistant.components.airgradient import DOMAIN from homeassistant.components.zeroconf import ZeroconfServiceInfo @@ -32,7 +32,7 @@ ZEROCONF_DISCOVERY = ZeroconfServiceInfo( async def test_full_flow( hass: HomeAssistant, - mock_airgradient_client: AsyncMock, + mock_new_airgradient_client: AsyncMock, mock_setup_entry: AsyncMock, ) -> None: """Test full flow.""" @@ -55,6 +55,31 @@ async def test_full_flow( CONF_HOST: "10.0.0.131", } assert result["result"].unique_id == "84fce612f5b8" + mock_new_airgradient_client.set_configuration_control.assert_awaited_once_with( + ConfigurationControl.LOCAL + ) + + +async def test_flow_with_registered_device( + hass: HomeAssistant, + mock_cloud_airgradient_client: AsyncMock, + mock_setup_entry: AsyncMock, +) -> None: + """Test we don't revert the cloud setting.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_USER}, + ) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "user" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_HOST: "10.0.0.131"}, + ) + assert result["type"] is FlowResultType.CREATE_ENTRY + assert result["result"].unique_id == "84fce612f5b8" + mock_cloud_airgradient_client.set_configuration_control.assert_not_called() async def test_flow_errors( @@ -123,7 +148,7 @@ async def test_duplicate( async def test_zeroconf_flow( hass: HomeAssistant, - mock_airgradient_client: AsyncMock, + mock_new_airgradient_client: AsyncMock, mock_setup_entry: AsyncMock, ) -> None: """Test zeroconf flow.""" @@ -147,3 +172,28 @@ async def test_zeroconf_flow( CONF_HOST: "10.0.0.131", } assert result["result"].unique_id == "84fce612f5b8" + mock_new_airgradient_client.set_configuration_control.assert_awaited_once_with( + ConfigurationControl.LOCAL + ) + + +async def test_zeroconf_flow_cloud_device( + hass: HomeAssistant, + mock_cloud_airgradient_client: AsyncMock, + mock_setup_entry: AsyncMock, +) -> None: + """Test zeroconf flow doesn't revert the cloud setting.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_ZEROCONF}, + data=ZEROCONF_DISCOVERY, + ) + assert result["type"] is FlowResultType.FORM + assert result["step_id"] == "discovery_confirm" + + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {}, + ) + assert result["type"] is FlowResultType.CREATE_ENTRY + mock_cloud_airgradient_client.set_configuration_control.assert_not_called()