Configure device in airgradient config flow (#118699)

This commit is contained in:
Joost Lekkerkerker 2024-06-03 19:23:07 +02:00 committed by GitHub
parent 91ca7db02f
commit 16485af7fc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 69 additions and 10 deletions

View File

@ -2,7 +2,7 @@
from typing import Any
from airgradient import AirGradientClient, AirGradientError
from airgradient import AirGradientClient, AirGradientError, ConfigurationControl
import voluptuous as vol
from homeassistant.components import zeroconf
@ -19,6 +19,14 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self) -> None:
"""Initialize the config flow."""
self.data: dict[str, Any] = {}
self.client: AirGradientClient | None = None
async def set_configuration_source(self) -> None:
"""Set configuration source to local if it hasn't been set yet."""
assert self.client
config = await self.client.get_config()
if config.configuration_control is ConfigurationControl.BOTH:
await self.client.set_configuration_control(ConfigurationControl.LOCAL)
async def async_step_zeroconf(
self, discovery_info: zeroconf.ZeroconfServiceInfo
@ -31,8 +39,8 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN):
self._abort_if_unique_id_configured(updates={CONF_HOST: host})
session = async_get_clientsession(self.hass)
air_gradient = AirGradientClient(host, session=session)
await air_gradient.get_current_measures()
self.client = AirGradientClient(host, session=session)
await self.client.get_current_measures()
self.context["title_placeholders"] = {
"model": self.data[CONF_MODEL],
@ -44,6 +52,7 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN):
) -> ConfigFlowResult:
"""Confirm discovery."""
if user_input is not None:
await self.set_configuration_source()
return self.async_create_entry(
title=self.data[CONF_MODEL],
data={CONF_HOST: self.data[CONF_HOST]},
@ -64,14 +73,15 @@ class AirGradientConfigFlow(ConfigFlow, domain=DOMAIN):
errors: dict[str, str] = {}
if user_input:
session = async_get_clientsession(self.hass)
air_gradient = AirGradientClient(user_input[CONF_HOST], session=session)
self.client = AirGradientClient(user_input[CONF_HOST], session=session)
try:
current_measures = await air_gradient.get_current_measures()
current_measures = await self.client.get_current_measures()
except AirGradientError:
errors["base"] = "cannot_connect"
else:
await self.async_set_unique_id(current_measures.serial_number)
self._abort_if_unique_id_configured()
await self.set_configuration_source()
return self.async_create_entry(
title=current_measures.model,
data={CONF_HOST: user_input[CONF_HOST]},

View File

@ -28,8 +28,7 @@
"name": "Configuration source",
"state": {
"cloud": "Cloud",
"local": "Local",
"both": "Both"
"local": "Local"
}
},
"display_temperature_unit": {

View File

@ -3,7 +3,7 @@
from ipaddress import ip_address
from unittest.mock import AsyncMock
from airgradient import AirGradientConnectionError
from airgradient import AirGradientConnectionError, ConfigurationControl
from homeassistant.components.airgradient import DOMAIN
from homeassistant.components.zeroconf import ZeroconfServiceInfo
@ -32,7 +32,7 @@ ZEROCONF_DISCOVERY = ZeroconfServiceInfo(
async def test_full_flow(
hass: HomeAssistant,
mock_airgradient_client: AsyncMock,
mock_new_airgradient_client: AsyncMock,
mock_setup_entry: AsyncMock,
) -> None:
"""Test full flow."""
@ -55,6 +55,31 @@ async def test_full_flow(
CONF_HOST: "10.0.0.131",
}
assert result["result"].unique_id == "84fce612f5b8"
mock_new_airgradient_client.set_configuration_control.assert_awaited_once_with(
ConfigurationControl.LOCAL
)
async def test_flow_with_registered_device(
hass: HomeAssistant,
mock_cloud_airgradient_client: AsyncMock,
mock_setup_entry: AsyncMock,
) -> None:
"""Test we don't revert the cloud setting."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_USER},
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "user"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{CONF_HOST: "10.0.0.131"},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["result"].unique_id == "84fce612f5b8"
mock_cloud_airgradient_client.set_configuration_control.assert_not_called()
async def test_flow_errors(
@ -123,7 +148,7 @@ async def test_duplicate(
async def test_zeroconf_flow(
hass: HomeAssistant,
mock_airgradient_client: AsyncMock,
mock_new_airgradient_client: AsyncMock,
mock_setup_entry: AsyncMock,
) -> None:
"""Test zeroconf flow."""
@ -147,3 +172,28 @@ async def test_zeroconf_flow(
CONF_HOST: "10.0.0.131",
}
assert result["result"].unique_id == "84fce612f5b8"
mock_new_airgradient_client.set_configuration_control.assert_awaited_once_with(
ConfigurationControl.LOCAL
)
async def test_zeroconf_flow_cloud_device(
hass: HomeAssistant,
mock_cloud_airgradient_client: AsyncMock,
mock_setup_entry: AsyncMock,
) -> None:
"""Test zeroconf flow doesn't revert the cloud setting."""
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_ZEROCONF},
data=ZEROCONF_DISCOVERY,
)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "discovery_confirm"
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{},
)
assert result["type"] is FlowResultType.CREATE_ENTRY
mock_cloud_airgradient_client.set_configuration_control.assert_not_called()