From d1f1bdebde65222a8426d37850d8b671980f8ae5 Mon Sep 17 00:00:00 2001 From: Joost Lekkerkerker Date: Thu, 28 Sep 2023 10:55:48 +0200 Subject: [PATCH] Add feature to add measuring station via number in waqi (#99992) * Add feature to add measuring station via number * Add feature to add measuring station via number * Add feature to add measuring station via number --- homeassistant/components/waqi/config_flow.py | 123 ++++++++++-- homeassistant/components/waqi/strings.json | 22 ++- tests/components/waqi/test_config_flow.py | 191 +++++++++++++++++-- 3 files changed, 301 insertions(+), 35 deletions(-) diff --git a/homeassistant/components/waqi/config_flow.py b/homeassistant/components/waqi/config_flow.py index b5f3a18b223..8404b425678 100644 --- a/homeassistant/components/waqi/config_flow.py +++ b/homeassistant/components/waqi/config_flow.py @@ -1,6 +1,7 @@ """Config flow for World Air Quality Index (WAQI) integration.""" from __future__ import annotations +from collections.abc import Awaitable, Callable import logging from typing import Any @@ -18,25 +19,36 @@ from homeassistant.const import ( CONF_LATITUDE, CONF_LOCATION, CONF_LONGITUDE, + CONF_METHOD, CONF_NAME, ) from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN from homeassistant.data_entry_flow import AbortFlow, FlowResult from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue -from homeassistant.helpers.selector import LocationSelector +from homeassistant.helpers.selector import ( + LocationSelector, + SelectSelector, + SelectSelectorConfig, +) from homeassistant.helpers.typing import ConfigType from .const import CONF_STATION_NUMBER, DOMAIN, ISSUE_PLACEHOLDER _LOGGER = logging.getLogger(__name__) +CONF_MAP = "map" + class WAQIConfigFlow(ConfigFlow, domain=DOMAIN): """Handle a config flow for World Air Quality Index (WAQI).""" VERSION = 1 + def __init__(self) -> None: + """Initialize config flow.""" + self.data: dict[str, Any] = {} + async def async_step_user( self, user_input: dict[str, Any] | None = None ) -> FlowResult: @@ -47,13 +59,8 @@ class WAQIConfigFlow(ConfigFlow, domain=DOMAIN): session=async_get_clientsession(self.hass) ) as waqi_client: waqi_client.authenticate(user_input[CONF_API_KEY]) - location = user_input[CONF_LOCATION] try: - measuring_station: WAQIAirQuality = ( - await waqi_client.get_by_coordinates( - location[CONF_LATITUDE], location[CONF_LONGITUDE] - ) - ) + await waqi_client.get_by_ip() except WAQIAuthenticationError: errors["base"] = "invalid_auth" except WAQIConnectionError: @@ -62,36 +69,110 @@ class WAQIConfigFlow(ConfigFlow, domain=DOMAIN): _LOGGER.exception(exc) errors["base"] = "unknown" else: - await self.async_set_unique_id(str(measuring_station.station_id)) - self._abort_if_unique_id_configured() - return self.async_create_entry( - title=measuring_station.city.name, - data={ - CONF_API_KEY: user_input[CONF_API_KEY], - CONF_STATION_NUMBER: measuring_station.station_id, - }, - ) + self.data = user_input + if user_input[CONF_METHOD] == CONF_MAP: + return await self.async_step_map() + return await self.async_step_station_number() return self.async_show_form( step_id="user", - data_schema=self.add_suggested_values_to_schema( + data_schema=vol.Schema( + { + vol.Required(CONF_API_KEY): str, + vol.Required(CONF_METHOD): SelectSelector( + SelectSelectorConfig( + options=[CONF_MAP, CONF_STATION_NUMBER], + translation_key="method", + ) + ), + } + ), + errors=errors, + ) + + async def _async_base_step( + self, + step_id: str, + method: Callable[[WAQIClient, dict[str, Any]], Awaitable[WAQIAirQuality]], + data_schema: vol.Schema, + user_input: dict[str, Any] | None = None, + ) -> FlowResult: + errors: dict[str, str] = {} + if user_input is not None: + async with WAQIClient( + session=async_get_clientsession(self.hass) + ) as waqi_client: + waqi_client.authenticate(self.data[CONF_API_KEY]) + try: + measuring_station = await method(waqi_client, user_input) + except WAQIConnectionError: + errors["base"] = "cannot_connect" + except Exception as exc: # pylint: disable=broad-except + _LOGGER.exception(exc) + errors["base"] = "unknown" + else: + return await self._async_create_entry(measuring_station) + return self.async_show_form( + step_id=step_id, data_schema=data_schema, errors=errors + ) + + async def async_step_map( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Add measuring station via map.""" + return await self._async_base_step( + CONF_MAP, + lambda waqi_client, data: waqi_client.get_by_coordinates( + data[CONF_LOCATION][CONF_LATITUDE], data[CONF_LOCATION][CONF_LONGITUDE] + ), + self.add_suggested_values_to_schema( vol.Schema( { - vol.Required(CONF_API_KEY): str, vol.Required( CONF_LOCATION, ): LocationSelector(), } ), - user_input - or { + { CONF_LOCATION: { CONF_LATITUDE: self.hass.config.latitude, CONF_LONGITUDE: self.hass.config.longitude, } }, ), - errors=errors, + user_input, + ) + + async def async_step_station_number( + self, user_input: dict[str, Any] | None = None + ) -> FlowResult: + """Add measuring station via station number.""" + return await self._async_base_step( + CONF_STATION_NUMBER, + lambda waqi_client, data: waqi_client.get_by_station_number( + data[CONF_STATION_NUMBER] + ), + vol.Schema( + { + vol.Required( + CONF_STATION_NUMBER, + ): int, + } + ), + user_input, + ) + + async def _async_create_entry( + self, measuring_station: WAQIAirQuality + ) -> FlowResult: + await self.async_set_unique_id(str(measuring_station.station_id)) + self._abort_if_unique_id_configured() + return self.async_create_entry( + title=measuring_station.city.name, + data={ + CONF_API_KEY: self.data[CONF_API_KEY], + CONF_STATION_NUMBER: measuring_station.station_id, + }, ) async def async_step_import(self, import_config: ConfigType) -> FlowResult: diff --git a/homeassistant/components/waqi/strings.json b/homeassistant/components/waqi/strings.json index 4ceb911de9e..46031a3072b 100644 --- a/homeassistant/components/waqi/strings.json +++ b/homeassistant/components/waqi/strings.json @@ -2,10 +2,20 @@ "config": { "step": { "user": { + "data": { + "api_key": "[%key:common::config_flow::data::api_key%]", + "method": "How do you want to select a measuring station?" + } + }, + "map": { "description": "Select a location to get the closest measuring station.", "data": { - "location": "[%key:common::config_flow::data::location%]", - "api_key": "[%key:common::config_flow::data::api_key%]" + "location": "[%key:common::config_flow::data::location%]" + } + }, + "station_number": { + "data": { + "station_number": "Measuring station number" } } }, @@ -18,6 +28,14 @@ "already_configured": "[%key:common::config_flow::abort::already_configured_device%]" } }, + "selector": { + "method": { + "options": { + "map": "Select nearest from point on the map", + "station_number": "Enter a station number" + } + } + }, "issues": { "deprecated_yaml_import_issue_invalid_auth": { "title": "The World Air Quality Index YAML configuration import failed", diff --git a/tests/components/waqi/test_config_flow.py b/tests/components/waqi/test_config_flow.py index 3901ffad550..be738a119e5 100644 --- a/tests/components/waqi/test_config_flow.py +++ b/tests/components/waqi/test_config_flow.py @@ -1,17 +1,20 @@ """Test the World Air Quality Index (WAQI) config flow.""" import json +from typing import Any from unittest.mock import AsyncMock, patch from aiowaqi import WAQIAirQuality, WAQIAuthenticationError, WAQIConnectionError import pytest from homeassistant import config_entries +from homeassistant.components.waqi.config_flow import CONF_MAP from homeassistant.components.waqi.const import CONF_STATION_NUMBER, DOMAIN from homeassistant.const import ( CONF_API_KEY, CONF_LATITUDE, CONF_LOCATION, CONF_LONGITUDE, + CONF_METHOD, ) from homeassistant.core import HomeAssistant from homeassistant.data_entry_flow import FlowResultType @@ -21,7 +24,29 @@ from tests.common import load_fixture pytestmark = pytest.mark.usefixtures("mock_setup_entry") -async def test_full_flow(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None: +@pytest.mark.parametrize( + ("method", "payload"), + [ + ( + CONF_MAP, + { + CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, + }, + ), + ( + CONF_STATION_NUMBER, + { + CONF_STATION_NUMBER: 4584, + }, + ), + ], +) +async def test_full_map_flow( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + method: str, + payload: dict[str, Any], +) -> None: """Test we get the form.""" result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} @@ -31,17 +56,36 @@ async def test_full_flow(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> No with patch( "aiowaqi.WAQIClient.authenticate", ), patch( - "aiowaqi.WAQIClient.get_by_coordinates", + "aiowaqi.WAQIClient.get_by_ip", return_value=WAQIAirQuality.parse_obj( json.loads(load_fixture("waqi/air_quality_sensor.json")) ), ): result = await hass.config_entries.flow.async_configure( result["flow_id"], - { - CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, - CONF_API_KEY: "asd", - }, + {CONF_API_KEY: "asd", CONF_METHOD: method}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == method + + with patch( + "aiowaqi.WAQIClient.authenticate", + ), patch( + "aiowaqi.WAQIClient.get_by_coordinates", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ), patch( + "aiowaqi.WAQIClient.get_by_station_number", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + payload, ) await hass.async_block_till_done() @@ -73,21 +117,35 @@ async def test_flow_errors( with patch( "aiowaqi.WAQIClient.authenticate", ), patch( - "aiowaqi.WAQIClient.get_by_coordinates", + "aiowaqi.WAQIClient.get_by_ip", side_effect=exception, ): result = await hass.config_entries.flow.async_configure( result["flow_id"], - { - CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, - CONF_API_KEY: "asd", - }, + {CONF_API_KEY: "asd", CONF_METHOD: CONF_MAP}, ) await hass.async_block_till_done() assert result["type"] == FlowResultType.FORM assert result["errors"] == {"base": error} + with patch( + "aiowaqi.WAQIClient.authenticate", + ), patch( + "aiowaqi.WAQIClient.get_by_ip", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_API_KEY: "asd", CONF_METHOD: CONF_MAP}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == "map" + with patch( "aiowaqi.WAQIClient.authenticate", ), patch( @@ -100,9 +158,118 @@ async def test_flow_errors( result["flow_id"], { CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, - CONF_API_KEY: "asd", }, ) await hass.async_block_till_done() assert result["type"] == FlowResultType.CREATE_ENTRY + + +@pytest.mark.parametrize( + ("method", "payload", "exception", "error"), + [ + ( + CONF_MAP, + { + CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, + }, + WAQIConnectionError(), + "cannot_connect", + ), + ( + CONF_MAP, + { + CONF_LOCATION: {CONF_LATITUDE: 50.0, CONF_LONGITUDE: 10.0}, + }, + Exception(), + "unknown", + ), + ( + CONF_STATION_NUMBER, + { + CONF_STATION_NUMBER: 4584, + }, + WAQIConnectionError(), + "cannot_connect", + ), + ( + CONF_STATION_NUMBER, + { + CONF_STATION_NUMBER: 4584, + }, + Exception(), + "unknown", + ), + ], +) +async def test_error_in_second_step( + hass: HomeAssistant, + mock_setup_entry: AsyncMock, + method: str, + payload: dict[str, Any], + exception: Exception, + error: str, +) -> None: + """Test we get the form.""" + result = await hass.config_entries.flow.async_init( + DOMAIN, context={"source": config_entries.SOURCE_USER} + ) + assert result["type"] == FlowResultType.FORM + + with patch( + "aiowaqi.WAQIClient.authenticate", + ), patch( + "aiowaqi.WAQIClient.get_by_ip", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_API_KEY: "asd", CONF_METHOD: method}, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + assert result["step_id"] == method + + with patch( + "aiowaqi.WAQIClient.authenticate", + ), patch( + "aiowaqi.WAQIClient.get_by_coordinates", side_effect=exception + ), patch("aiowaqi.WAQIClient.get_by_station_number", side_effect=exception): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + payload, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.FORM + assert result["errors"] == {"base": error} + + with patch( + "aiowaqi.WAQIClient.authenticate", + ), patch( + "aiowaqi.WAQIClient.get_by_coordinates", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ), patch( + "aiowaqi.WAQIClient.get_by_station_number", + return_value=WAQIAirQuality.parse_obj( + json.loads(load_fixture("waqi/air_quality_sensor.json")) + ), + ): + result = await hass.config_entries.flow.async_configure( + result["flow_id"], + payload, + ) + await hass.async_block_till_done() + + assert result["type"] == FlowResultType.CREATE_ENTRY + assert result["title"] == "de Jongweg, Utrecht" + assert result["data"] == { + CONF_API_KEY: "asd", + CONF_STATION_NUMBER: 4584, + } + assert len(mock_setup_entry.mock_calls) == 1