From 0e2fa7700d4abd7d12601fbaa20780c1112abdf6 Mon Sep 17 00:00:00 2001 From: Aaron Bach Date: Tue, 24 Mar 2020 12:39:38 -0600 Subject: [PATCH] =?UTF-8?q?Allow=20more=20than=20one=20AirVisual=20config?= =?UTF-8?q?=20entry=20with=20the=20same=20API=20k=E2=80=A6=20(#33072)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Allow more than one AirVisual config entry with the same API key * Add tests * Correctly pop geography * Code review * Code review --- .../airvisual/.translations/en.json | 2 +- .../components/airvisual/__init__.py | 111 +++++++++++------- .../components/airvisual/config_flow.py | 53 ++++----- homeassistant/components/airvisual/const.py | 1 - homeassistant/components/airvisual/sensor.py | 15 ++- .../components/airvisual/strings.json | 2 +- .../components/airvisual/test_config_flow.py | 71 +++++++++-- 7 files changed, 169 insertions(+), 86 deletions(-) diff --git a/homeassistant/components/airvisual/.translations/en.json b/homeassistant/components/airvisual/.translations/en.json index 604baf1feb6..982ed8e13e7 100644 --- a/homeassistant/components/airvisual/.translations/en.json +++ b/homeassistant/components/airvisual/.translations/en.json @@ -1,7 +1,7 @@ { "config": { "abort": { - "already_configured": "This API key is already in use." + "already_configured": "These coordinates have already been registered." }, "error": { "invalid_api_key": "Invalid API key" diff --git a/homeassistant/components/airvisual/__init__.py b/homeassistant/components/airvisual/__init__.py index a48acf7bb34..e234c2b1c67 100644 --- a/homeassistant/components/airvisual/__init__.py +++ b/homeassistant/components/airvisual/__init__.py @@ -1,5 +1,4 @@ """The airvisual component.""" -import asyncio import logging from pyairvisual import Client @@ -23,7 +22,6 @@ from homeassistant.helpers.event import async_track_time_interval from .const import ( CONF_CITY, CONF_COUNTRY, - CONF_GEOGRAPHIES, DATA_CLIENT, DEFAULT_SCAN_INTERVAL, DOMAIN, @@ -36,7 +34,7 @@ DATA_LISTENER = "listener" DEFAULT_OPTIONS = {CONF_SHOW_ON_MAP: True} -CONF_NODE_ID = "node_id" +CONF_GEOGRAPHIES = "geographies" GEOGRAPHY_COORDINATES_SCHEMA = vol.Schema( { @@ -70,34 +68,38 @@ CONFIG_SCHEMA = vol.Schema({DOMAIN: CLOUD_API_SCHEMA}, extra=vol.ALLOW_EXTRA) def async_get_geography_id(geography_dict): """Generate a unique ID from a geography dict.""" if CONF_CITY in geography_dict: - return ",".join( + return ", ".join( ( geography_dict[CONF_CITY], geography_dict[CONF_STATE], geography_dict[CONF_COUNTRY], ) ) - return ",".join( + return ", ".join( (str(geography_dict[CONF_LATITUDE]), str(geography_dict[CONF_LONGITUDE])) ) async def async_setup(hass, config): """Set up the AirVisual component.""" - hass.data[DOMAIN] = {} - hass.data[DOMAIN][DATA_CLIENT] = {} - hass.data[DOMAIN][DATA_LISTENER] = {} + hass.data[DOMAIN] = {DATA_CLIENT: {}, DATA_LISTENER: {}} if DOMAIN not in config: return True conf = config[DOMAIN] - hass.async_create_task( - hass.config_entries.flow.async_init( - DOMAIN, context={"source": SOURCE_IMPORT}, data=conf + for geography in conf.get( + CONF_GEOGRAPHIES, + [{CONF_LATITUDE: hass.config.latitude, CONF_LONGITUDE: hass.config.longitude}], + ): + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_IMPORT}, + data={CONF_API_KEY: conf[CONF_API_KEY], **geography}, + ) ) - ) return True @@ -144,6 +146,45 @@ async def async_setup_entry(hass, config_entry): return True +async def async_migrate_entry(hass, config_entry): + """Migrate an old config entry.""" + version = config_entry.version + + _LOGGER.debug("Migrating from version %s", version) + + # 1 -> 2: One geography per config entry + if version == 1: + version = config_entry.version = 2 + + # Update the config entry to only include the first geography (there is always + # guaranteed to be at least one): + data = {**config_entry.data} + geographies = data.pop(CONF_GEOGRAPHIES) + first_geography = geographies.pop(0) + first_id = async_get_geography_id(first_geography) + + hass.config_entries.async_update_entry( + config_entry, + unique_id=first_id, + title=f"Cloud API ({first_id})", + data={CONF_API_KEY: config_entry.data[CONF_API_KEY], **first_geography}, + ) + + # For any geographies that remain, create a new config entry for each one: + for geography in geographies: + hass.async_create_task( + hass.config_entries.flow.async_init( + DOMAIN, + context={"source": SOURCE_IMPORT}, + data={CONF_API_KEY: config_entry.data[CONF_API_KEY], **geography}, + ) + ) + + _LOGGER.info("Migration to version %s successful", version) + + return True + + async def async_unload_entry(hass, config_entry): """Unload an AirVisual config entry.""" hass.data[DOMAIN][DATA_CLIENT].pop(config_entry.entry_id) @@ -170,40 +211,28 @@ class AirVisualData: self._client = client self._hass = hass self.data = {} + self.geography_data = config_entry.data + self.geography_id = config_entry.unique_id self.options = config_entry.options - self.geographies = { - async_get_geography_id(geography): geography - for geography in config_entry.data[CONF_GEOGRAPHIES] - } - async def async_update(self): """Get new data for all locations from the AirVisual cloud API.""" - tasks = [] + if CONF_CITY in self.geography_data: + api_coro = self._client.api.city( + self.geography_data[CONF_CITY], + self.geography_data[CONF_STATE], + self.geography_data[CONF_COUNTRY], + ) + else: + api_coro = self._client.api.nearest_city( + self.geography_data[CONF_LATITUDE], self.geography_data[CONF_LONGITUDE], + ) - for geography in self.geographies.values(): - if CONF_CITY in geography: - tasks.append( - self._client.api.city( - geography[CONF_CITY], - geography[CONF_STATE], - geography[CONF_COUNTRY], - ) - ) - else: - tasks.append( - self._client.api.nearest_city( - geography[CONF_LATITUDE], geography[CONF_LONGITUDE], - ) - ) - - results = await asyncio.gather(*tasks, return_exceptions=True) - for geography_id, result in zip(self.geographies, results): - if isinstance(result, AirVisualError): - _LOGGER.error("Error while retrieving data: %s", result) - self.data[geography_id] = {} - continue - self.data[geography_id] = result + try: + self.data[self.geography_id] = await api_coro + except AirVisualError as err: + _LOGGER.error("Error while retrieving data: %s", err) + self.data[self.geography_id] = {} _LOGGER.debug("Received new data") async_dispatcher_send(self._hass, TOPIC_UPDATE) diff --git a/homeassistant/components/airvisual/config_flow.py b/homeassistant/components/airvisual/config_flow.py index 2f961ccfb49..047f585a4ff 100644 --- a/homeassistant/components/airvisual/config_flow.py +++ b/homeassistant/components/airvisual/config_flow.py @@ -1,5 +1,5 @@ """Define a config flow manager for AirVisual.""" -import logging +import asyncio from pyairvisual import Client from pyairvisual.errors import InvalidKeyError @@ -15,15 +15,14 @@ from homeassistant.const import ( from homeassistant.core import callback from homeassistant.helpers import aiohttp_client, config_validation as cv -from .const import CONF_GEOGRAPHIES, DOMAIN # pylint: disable=unused-import - -_LOGGER = logging.getLogger("homeassistant.components.airvisual") +from . import async_get_geography_id +from .const import DOMAIN # pylint: disable=unused-import class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): """Handle an AirVisual config flow.""" - VERSION = 1 + VERSION = 2 CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL @property @@ -68,35 +67,33 @@ class AirVisualFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): if not user_input: return await self._show_form() - await self._async_set_unique_id(user_input[CONF_API_KEY]) + geo_id = async_get_geography_id(user_input) + await self._async_set_unique_id(geo_id) websession = aiohttp_client.async_get_clientsession(self.hass) client = Client(websession, api_key=user_input[CONF_API_KEY]) - try: - await client.api.nearest_city() - except InvalidKeyError: - return await self._show_form(errors={CONF_API_KEY: "invalid_api_key"}) - - data = {CONF_API_KEY: user_input[CONF_API_KEY]} - if user_input.get(CONF_GEOGRAPHIES): - data[CONF_GEOGRAPHIES] = user_input[CONF_GEOGRAPHIES] - else: - data[CONF_GEOGRAPHIES] = [ - { - CONF_LATITUDE: user_input.get( - CONF_LATITUDE, self.hass.config.latitude - ), - CONF_LONGITUDE: user_input.get( - CONF_LONGITUDE, self.hass.config.longitude - ), - } - ] - - return self.async_create_entry( - title=f"Cloud API (API key: {user_input[CONF_API_KEY][:4]}...)", data=data + # If this is the first (and only the first) time we've seen this API key, check + # that it's valid: + checked_keys = self.hass.data.setdefault("airvisual_checked_api_keys", set()) + check_keys_lock = self.hass.data.setdefault( + "airvisual_checked_api_keys_lock", asyncio.Lock() ) + async with check_keys_lock: + if user_input[CONF_API_KEY] not in checked_keys: + try: + await client.api.nearest_city() + except InvalidKeyError: + return await self._show_form( + errors={CONF_API_KEY: "invalid_api_key"} + ) + + checked_keys.add(user_input[CONF_API_KEY]) + return self.async_create_entry( + title=f"Cloud API ({geo_id})", data=user_input + ) + class AirVisualOptionsFlowHandler(config_entries.OptionsFlow): """Handle an AirVisual options flow.""" diff --git a/homeassistant/components/airvisual/const.py b/homeassistant/components/airvisual/const.py index ab54e191116..3bfc224a735 100644 --- a/homeassistant/components/airvisual/const.py +++ b/homeassistant/components/airvisual/const.py @@ -5,7 +5,6 @@ DOMAIN = "airvisual" CONF_CITY = "city" CONF_COUNTRY = "country" -CONF_GEOGRAPHIES = "geographies" DATA_CLIENT = "client" diff --git a/homeassistant/components/airvisual/sensor.py b/homeassistant/components/airvisual/sensor.py index 28d2b3f5f86..49a5f53361f 100644 --- a/homeassistant/components/airvisual/sensor.py +++ b/homeassistant/components/airvisual/sensor.py @@ -191,16 +191,19 @@ class AirVisualSensor(Entity): } ) - geography = self._airvisual.geographies[self._geography_id] - if CONF_LATITUDE in geography: + if CONF_LATITUDE in self._airvisual.geography_data: if self._airvisual.options[CONF_SHOW_ON_MAP]: - self._attrs[ATTR_LATITUDE] = geography[CONF_LATITUDE] - self._attrs[ATTR_LONGITUDE] = geography[CONF_LONGITUDE] + self._attrs[ATTR_LATITUDE] = self._airvisual.geography_data[ + CONF_LATITUDE + ] + self._attrs[ATTR_LONGITUDE] = self._airvisual.geography_data[ + CONF_LONGITUDE + ] self._attrs.pop("lati", None) self._attrs.pop("long", None) else: - self._attrs["lati"] = geography[CONF_LATITUDE] - self._attrs["long"] = geography[CONF_LONGITUDE] + self._attrs["lati"] = self._airvisual.geography_data[CONF_LATITUDE] + self._attrs["long"] = self._airvisual.geography_data[CONF_LONGITUDE] self._attrs.pop(ATTR_LATITUDE, None) self._attrs.pop(ATTR_LONGITUDE, None) diff --git a/homeassistant/components/airvisual/strings.json b/homeassistant/components/airvisual/strings.json index 6e94c393da6..8791e6d864d 100644 --- a/homeassistant/components/airvisual/strings.json +++ b/homeassistant/components/airvisual/strings.json @@ -16,7 +16,7 @@ "invalid_api_key": "Invalid API key" }, "abort": { - "already_configured": "This API key is already in use." + "already_configured": "These coordinates have already been registered." } }, "options": { diff --git a/tests/components/airvisual/test_config_flow.py b/tests/components/airvisual/test_config_flow.py index fb32a86a01a..d21aec14fa0 100644 --- a/tests/components/airvisual/test_config_flow.py +++ b/tests/components/airvisual/test_config_flow.py @@ -11,15 +11,22 @@ from homeassistant.const import ( CONF_LONGITUDE, CONF_SHOW_ON_MAP, ) +from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry async def test_duplicate_error(hass): """Test that errors are shown when duplicates are added.""" - conf = {CONF_API_KEY: "abcde12345"} + conf = { + CONF_API_KEY: "abcde12345", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + } - MockConfigEntry(domain=DOMAIN, unique_id="abcde12345", data=conf).add_to_hass(hass) + MockConfigEntry( + domain=DOMAIN, unique_id="51.528308, -0.3817765", data=conf + ).add_to_hass(hass) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_USER}, data=conf @@ -31,7 +38,11 @@ async def test_duplicate_error(hass): async def test_invalid_api_key(hass): """Test that invalid credentials throws an error.""" - conf = {CONF_API_KEY: "abcde12345"} + conf = { + CONF_API_KEY: "abcde12345", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + } with patch( "pyairvisual.api.API.nearest_city", side_effect=InvalidKeyError, @@ -42,6 +53,47 @@ async def test_invalid_api_key(hass): assert result["errors"] == {CONF_API_KEY: "invalid_api_key"} +async def test_migration_1_2(hass): + """Test migrating from version 1 to version 2.""" + conf = { + CONF_API_KEY: "abcde12345", + CONF_GEOGRAPHIES: [ + {CONF_LATITUDE: 51.528308, CONF_LONGITUDE: -0.3817765}, + {CONF_LATITUDE: 35.48847, CONF_LONGITUDE: 137.5263065}, + ], + } + + config_entry = MockConfigEntry( + domain=DOMAIN, version=1, unique_id="abcde12345", data=conf + ) + config_entry.add_to_hass(hass) + + assert len(hass.config_entries.async_entries(DOMAIN)) == 1 + + with patch("pyairvisual.api.API.nearest_city"): + assert await async_setup_component(hass, DOMAIN, {DOMAIN: conf}) + + config_entries = hass.config_entries.async_entries(DOMAIN) + + assert len(config_entries) == 2 + + assert config_entries[0].unique_id == "51.528308, -0.3817765" + assert config_entries[0].title == "Cloud API (51.528308, -0.3817765)" + assert config_entries[0].data == { + CONF_API_KEY: "abcde12345", + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, + } + + assert config_entries[1].unique_id == "35.48847, 137.5263065" + assert config_entries[1].title == "Cloud API (35.48847, 137.5263065)" + assert config_entries[1].data == { + CONF_API_KEY: "abcde12345", + CONF_LATITUDE: 35.48847, + CONF_LONGITUDE: 137.5263065, + } + + async def test_options_flow(hass): """Test config flow options.""" conf = {CONF_API_KEY: "abcde12345"} @@ -84,7 +136,8 @@ async def test_step_import(hass): """Test that the import step works.""" conf = { CONF_API_KEY: "abcde12345", - CONF_GEOGRAPHIES: [{CONF_LATITUDE: 51.528308, CONF_LONGITUDE: -0.3817765}], + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, } with patch( @@ -95,10 +148,11 @@ async def test_step_import(hass): ) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == "Cloud API (API key: abcd...)" + assert result["title"] == "Cloud API (51.528308, -0.3817765)" assert result["data"] == { CONF_API_KEY: "abcde12345", - CONF_GEOGRAPHIES: [{CONF_LATITUDE: 51.528308, CONF_LONGITUDE: -0.3817765}], + CONF_LATITUDE: 51.528308, + CONF_LONGITUDE: -0.3817765, } @@ -117,8 +171,9 @@ async def test_step_user(hass): DOMAIN, context={"source": SOURCE_USER}, data=conf ) assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY - assert result["title"] == "Cloud API (API key: abcd...)" + assert result["title"] == "Cloud API (32.87336, -117.22743)" assert result["data"] == { CONF_API_KEY: "abcde12345", - CONF_GEOGRAPHIES: [{CONF_LATITUDE: 32.87336, CONF_LONGITUDE: -117.22743}], + CONF_LATITUDE: 32.87336, + CONF_LONGITUDE: -117.22743, }