Modernize WWLLN config flow (#32194)

* Modernize WWLLN config flow

* Code review

* Update tests
This commit is contained in:
Paulus Schoutsen 2020-03-04 18:23:00 -08:00 committed by GitHub
parent 56cf4e54a9
commit 81810dd920
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 152 additions and 145 deletions

View File

@ -1,7 +1,8 @@
{ {
"config": { "config": {
"error": { "abort": {
"identifier_exists": "Location already registered" "already_configured": "This location is already registered.",
"window_too_small": "A too-small window will cause Home Assistant to miss events."
}, },
"step": { "step": {
"user": { "user": {

View File

@ -1,24 +1,19 @@
"""Support for World Wide Lightning Location Network.""" """Support for World Wide Lightning Location Network."""
import logging
from aiowwlln import Client from aiowwlln import Client
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import SOURCE_IMPORT from homeassistant.config_entries import SOURCE_IMPORT
from homeassistant.const import ( from homeassistant.const import CONF_LATITUDE, CONF_LONGITUDE, CONF_RADIUS
CONF_LATITUDE,
CONF_LONGITUDE,
CONF_RADIUS,
CONF_UNIT_SYSTEM,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
)
from homeassistant.helpers import aiohttp_client, config_validation as cv from homeassistant.helpers import aiohttp_client, config_validation as cv
from .config_flow import configured_instances from .const import (
from .const import CONF_WINDOW, DATA_CLIENT, DEFAULT_RADIUS, DEFAULT_WINDOW, DOMAIN CONF_WINDOW,
DATA_CLIENT,
_LOGGER = logging.getLogger(__name__) DEFAULT_RADIUS,
DEFAULT_WINDOW,
DOMAIN,
LOGGER,
)
CONFIG_SCHEMA = vol.Schema( CONFIG_SCHEMA = vol.Schema(
{ {
@ -28,7 +23,9 @@ CONFIG_SCHEMA = vol.Schema(
vol.Optional(CONF_LONGITUDE): cv.longitude, vol.Optional(CONF_LONGITUDE): cv.longitude,
vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS): cv.positive_int, vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS): cv.positive_int,
vol.Optional(CONF_WINDOW, default=DEFAULT_WINDOW): vol.All( vol.Optional(CONF_WINDOW, default=DEFAULT_WINDOW): vol.All(
cv.time_period, cv.positive_timedelta cv.time_period,
cv.positive_timedelta,
lambda value: value.total_seconds(),
), ),
} }
) )
@ -44,36 +41,9 @@ async def async_setup(hass, config):
conf = config[DOMAIN] conf = config[DOMAIN]
latitude = conf.get(CONF_LATITUDE, hass.config.latitude)
longitude = conf.get(CONF_LONGITUDE, hass.config.longitude)
identifier = f"{latitude}, {longitude}"
if identifier in configured_instances(hass):
return True
if conf[CONF_WINDOW] < DEFAULT_WINDOW:
_LOGGER.warning(
"Setting a window smaller than %s seconds may cause Home Assistant \
to miss events",
DEFAULT_WINDOW.total_seconds(),
)
if hass.config.units.name == CONF_UNIT_SYSTEM_IMPERIAL:
unit_system = CONF_UNIT_SYSTEM_IMPERIAL
else:
unit_system = CONF_UNIT_SYSTEM_METRIC
hass.async_create_task( hass.async_create_task(
hass.config_entries.flow.async_init( hass.config_entries.flow.async_init(
DOMAIN, DOMAIN, context={"source": SOURCE_IMPORT}, data=conf
context={"source": SOURCE_IMPORT},
data={
CONF_LATITUDE: latitude,
CONF_LONGITUDE: longitude,
CONF_RADIUS: conf[CONF_RADIUS],
CONF_WINDOW: conf[CONF_WINDOW],
CONF_UNIT_SYSTEM: unit_system,
},
) )
) )
@ -82,6 +52,15 @@ async def async_setup(hass, config):
async def async_setup_entry(hass, config_entry): async def async_setup_entry(hass, config_entry):
"""Set up the WWLLN as config entry.""" """Set up the WWLLN as config entry."""
if not config_entry.unique_id:
hass.config_entries.async_update_entry(
config_entry,
unique_id=(
f"{config_entry.data[CONF_LATITUDE]}, "
f"{config_entry.data[CONF_LONGITUDE]}"
),
)
hass.data[DOMAIN] = {} hass.data[DOMAIN] = {}
hass.data[DOMAIN][DATA_CLIENT] = {} hass.data[DOMAIN][DATA_CLIENT] = {}
@ -112,7 +91,7 @@ async def async_migrate_entry(hass, config_entry):
default_total_seconds = DEFAULT_WINDOW.total_seconds() default_total_seconds = DEFAULT_WINDOW.total_seconds()
_LOGGER.debug("Migrating from version %s", version) LOGGER.debug("Migrating from version %s", version)
# 1 -> 2: Expanding the default window to 1 hour (if needed): # 1 -> 2: Expanding the default window to 1 hour (if needed):
if version == 1: if version == 1:
@ -120,6 +99,6 @@ async def async_migrate_entry(hass, config_entry):
data[CONF_WINDOW] = default_total_seconds data[CONF_WINDOW] = default_total_seconds
version = config_entry.version = 2 version = config_entry.version = 2
hass.config_entries.async_update_entry(config_entry, data=data) hass.config_entries.async_update_entry(config_entry, data=data)
_LOGGER.info("Migration to version %s successful", version) LOGGER.info("Migration to version %s successful", version)
return True return True

View File

@ -2,39 +2,28 @@
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import ( from homeassistant.const import CONF_LATITUDE, CONF_LONGITUDE, CONF_RADIUS
CONF_LATITUDE,
CONF_LONGITUDE,
CONF_RADIUS,
CONF_UNIT_SYSTEM,
CONF_UNIT_SYSTEM_IMPERIAL,
CONF_UNIT_SYSTEM_METRIC,
)
from homeassistant.core import callback
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from .const import CONF_WINDOW, DEFAULT_RADIUS, DEFAULT_WINDOW, DOMAIN from .const import ( # pylint: disable=unused-import
CONF_WINDOW,
DEFAULT_RADIUS,
DEFAULT_WINDOW,
DOMAIN,
LOGGER,
)
@callback class WWLLNFlowHandler(config_entries.ConfigFlow, domain=DOMAIN):
def configured_instances(hass):
"""Return a set of configured WWLLN instances."""
return set(
"{0}, {1}".format(entry.data[CONF_LATITUDE], entry.data[CONF_LONGITUDE])
for entry in hass.config_entries.async_entries(DOMAIN)
)
@config_entries.HANDLERS.register(DOMAIN)
class WWLLNFlowHandler(config_entries.ConfigFlow):
"""Handle a WWLLN config flow.""" """Handle a WWLLN config flow."""
VERSION = 2 VERSION = 2
CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL CONNECTION_CLASS = config_entries.CONN_CLASS_CLOUD_POLL
async def _show_form(self, errors=None): @property
"""Show the form to the user.""" def data_schema(self):
data_schema = vol.Schema( """Return the data schema for the user form."""
return vol.Schema(
{ {
vol.Optional( vol.Optional(
CONF_LATITUDE, default=self.hass.config.latitude CONF_LATITUDE, default=self.hass.config.latitude
@ -46,12 +35,26 @@ class WWLLNFlowHandler(config_entries.ConfigFlow):
} }
) )
async def _show_form(self, errors=None):
"""Show the form to the user."""
return self.async_show_form( return self.async_show_form(
step_id="user", data_schema=data_schema, errors=errors or {} step_id="user", data_schema=self.data_schema, errors=errors or {}
) )
async def async_step_import(self, import_config): async def async_step_import(self, import_config):
"""Import a config entry from configuration.yaml.""" """Import a config entry from configuration.yaml."""
default_window_seconds = DEFAULT_WINDOW.total_seconds()
if (
CONF_WINDOW in import_config
and import_config[CONF_WINDOW] < default_window_seconds
):
LOGGER.error(
"Refusing to use too-small window (%s < %s)",
import_config[CONF_WINDOW],
default_window_seconds,
)
return self.async_abort(reason="window_too_small")
return await self.async_step_user(import_config) return await self.async_step_user(import_config)
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
@ -59,25 +62,22 @@ class WWLLNFlowHandler(config_entries.ConfigFlow):
if not user_input: if not user_input:
return await self._show_form() return await self._show_form()
identifier = "{0}, {1}".format( latitude = user_input.get(CONF_LATITUDE, self.hass.config.latitude)
user_input[CONF_LATITUDE], user_input[CONF_LONGITUDE] longitude = user_input.get(CONF_LONGITUDE, self.hass.config.longitude)
identifier = f"{latitude}, {longitude}"
await self.async_set_unique_id(identifier)
self._abort_if_unique_id_configured()
return self.async_create_entry(
title=identifier,
data={
CONF_LATITUDE: latitude,
CONF_LONGITUDE: longitude,
CONF_RADIUS: user_input.get(CONF_RADIUS, DEFAULT_RADIUS),
CONF_WINDOW: user_input.get(
CONF_WINDOW, DEFAULT_WINDOW.total_seconds()
),
},
) )
if identifier in configured_instances(self.hass):
return await self._show_form({"base": "identifier_exists"})
if self.hass.config.units.name == CONF_UNIT_SYSTEM_IMPERIAL:
user_input[CONF_UNIT_SYSTEM] = CONF_UNIT_SYSTEM_IMPERIAL
else:
user_input[CONF_UNIT_SYSTEM] = CONF_UNIT_SYSTEM_METRIC
# When importing from `configuration.yaml`, we give the user
# flexibility by allowing the `window` parameter to be any type
# of time period. This will always return a timedelta; unfortunately,
# timedeltas aren't JSON-serializable, so we can't store them in a
# config entry as-is; instead, we save the total seconds as an int:
if CONF_WINDOW in user_input:
user_input[CONF_WINDOW] = user_input[CONF_WINDOW].total_seconds()
else:
user_input[CONF_WINDOW] = DEFAULT_WINDOW.total_seconds()
return self.async_create_entry(title=identifier, data=user_input)

View File

@ -1,5 +1,8 @@
"""Define constants for the WWLLN integration.""" """Define constants for the WWLLN integration."""
from datetime import timedelta from datetime import timedelta
import logging
LOGGER = logging.getLogger(__package__)
DOMAIN = "wwlln" DOMAIN = "wwlln"

View File

@ -1,6 +1,5 @@
"""Support for WWLLN geo location events.""" """Support for WWLLN geo location events."""
from datetime import timedelta from datetime import timedelta
import logging
from aiowwlln.errors import WWLLNError from aiowwlln.errors import WWLLNError
@ -10,7 +9,6 @@ from homeassistant.const import (
CONF_LATITUDE, CONF_LATITUDE,
CONF_LONGITUDE, CONF_LONGITUDE,
CONF_RADIUS, CONF_RADIUS,
CONF_UNIT_SYSTEM,
CONF_UNIT_SYSTEM_IMPERIAL, CONF_UNIT_SYSTEM_IMPERIAL,
LENGTH_KILOMETERS, LENGTH_KILOMETERS,
LENGTH_MILES, LENGTH_MILES,
@ -23,9 +21,7 @@ from homeassistant.helpers.dispatcher import (
from homeassistant.helpers.event import async_track_time_interval from homeassistant.helpers.event import async_track_time_interval
from homeassistant.util.dt import utc_from_timestamp from homeassistant.util.dt import utc_from_timestamp
from .const import CONF_WINDOW, DATA_CLIENT, DOMAIN from .const import CONF_WINDOW, DATA_CLIENT, DOMAIN, LOGGER
_LOGGER = logging.getLogger(__name__)
ATTR_EXTERNAL_ID = "external_id" ATTR_EXTERNAL_ID = "external_id"
ATTR_PUBLICATION_DATE = "publication_date" ATTR_PUBLICATION_DATE = "publication_date"
@ -49,7 +45,6 @@ async def async_setup_entry(hass, entry, async_add_entities):
entry.data[CONF_LONGITUDE], entry.data[CONF_LONGITUDE],
entry.data[CONF_RADIUS], entry.data[CONF_RADIUS],
entry.data[CONF_WINDOW], entry.data[CONF_WINDOW],
entry.data[CONF_UNIT_SYSTEM],
) )
await manager.async_init() await manager.async_init()
@ -66,7 +61,6 @@ class WWLLNEventManager:
longitude, longitude,
radius, radius,
window_seconds, window_seconds,
unit_system,
): ):
"""Initialize.""" """Initialize."""
self._async_add_entities = async_add_entities self._async_add_entities = async_add_entities
@ -79,8 +73,7 @@ class WWLLNEventManager:
self._strikes = {} self._strikes = {}
self._window = timedelta(seconds=window_seconds) self._window = timedelta(seconds=window_seconds)
self._unit_system = unit_system if hass.config.units.name == CONF_UNIT_SYSTEM_IMPERIAL:
if unit_system == CONF_UNIT_SYSTEM_IMPERIAL:
self._unit = LENGTH_MILES self._unit = LENGTH_MILES
else: else:
self._unit = LENGTH_KILOMETERS self._unit = LENGTH_KILOMETERS
@ -88,7 +81,7 @@ class WWLLNEventManager:
@callback @callback
def _create_events(self, ids_to_create): def _create_events(self, ids_to_create):
"""Create new geo location events.""" """Create new geo location events."""
_LOGGER.debug("Going to create %s", ids_to_create) LOGGER.debug("Going to create %s", ids_to_create)
events = [] events = []
for strike_id in ids_to_create: for strike_id in ids_to_create:
strike = self._strikes[strike_id] strike = self._strikes[strike_id]
@ -107,7 +100,7 @@ class WWLLNEventManager:
@callback @callback
def _remove_events(self, ids_to_remove): def _remove_events(self, ids_to_remove):
"""Remove old geo location events.""" """Remove old geo location events."""
_LOGGER.debug("Going to remove %s", ids_to_remove) LOGGER.debug("Going to remove %s", ids_to_remove)
for strike_id in ids_to_remove: for strike_id in ids_to_remove:
async_dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(strike_id)) async_dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(strike_id))
@ -123,18 +116,18 @@ class WWLLNEventManager:
async def async_update(self): async def async_update(self):
"""Refresh data.""" """Refresh data."""
_LOGGER.debug("Refreshing WWLLN data") LOGGER.debug("Refreshing WWLLN data")
try: try:
self._strikes = await self._client.within_radius( self._strikes = await self._client.within_radius(
self._latitude, self._latitude,
self._longitude, self._longitude,
self._radius, self._radius,
unit=self._unit_system, unit=self._hass.config.units.name,
window=self._window, window=self._window,
) )
except WWLLNError as err: except WWLLNError as err:
_LOGGER.error("Error while updating WWLLN data: %s", err) LOGGER.error("Error while updating WWLLN data: %s", err)
return return
new_strike_ids = set(self._strikes) new_strike_ids = set(self._strikes)

View File

@ -11,8 +11,9 @@
} }
} }
}, },
"error": { "abort": {
"identifier_exists": "Location already registered" "already_configured": "This location is already registered.",
"window_too_small": "A too-small window will cause Home Assistant to miss events."
} }
} }
} }

View File

@ -1,6 +1,4 @@
"""Define tests for the WWLLN config flow.""" """Define tests for the WWLLN config flow."""
from datetime import timedelta
from asynctest import patch from asynctest import patch
from homeassistant import data_entry_flow from homeassistant import data_entry_flow
@ -9,34 +7,34 @@ from homeassistant.components.wwlln import (
DATA_CLIENT, DATA_CLIENT,
DOMAIN, DOMAIN,
async_setup_entry, async_setup_entry,
config_flow,
)
from homeassistant.const import (
CONF_LATITUDE,
CONF_LONGITUDE,
CONF_RADIUS,
CONF_UNIT_SYSTEM,
) )
from homeassistant.config_entries import SOURCE_IMPORT, SOURCE_USER
from homeassistant.const import CONF_LATITUDE, CONF_LONGITUDE, CONF_RADIUS
from tests.common import MockConfigEntry
async def test_duplicate_error(hass, config_entry): async def test_duplicate_error(hass, config_entry):
"""Test that errors are shown when duplicates are added.""" """Test that errors are shown when duplicates are added."""
conf = {CONF_LATITUDE: 39.128712, CONF_LONGITUDE: -104.9812612, CONF_RADIUS: 25} conf = {CONF_LATITUDE: 39.128712, CONF_LONGITUDE: -104.9812612, CONF_RADIUS: 25}
config_entry.add_to_hass(hass) MockConfigEntry(
flow = config_flow.WWLLNFlowHandler() domain=DOMAIN, unique_id="39.128712, -104.9812612", data=conf
flow.hass = hass ).add_to_hass(hass)
result = await flow.async_step_user(user_input=conf) result = await hass.config_entries.flow.async_init(
assert result["errors"] == {"base": "identifier_exists"} DOMAIN, context={"source": SOURCE_USER}, data=conf
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "already_configured"
async def test_show_form(hass): async def test_show_form(hass):
"""Test that the form is served with no input.""" """Test that the form is served with no input."""
flow = config_flow.WWLLNFlowHandler() result = await hass.config_entries.flow.async_init(
flow.hass = hass DOMAIN, context={"source": SOURCE_USER},
)
result = await flow.async_step_user(user_input=None)
assert result["type"] == data_entry_flow.RESULT_TYPE_FORM assert result["type"] == data_entry_flow.RESULT_TYPE_FORM
assert result["step_id"] == "user" assert result["step_id"] == "user"
@ -44,46 +42,79 @@ async def test_show_form(hass):
async def test_step_import(hass): async def test_step_import(hass):
"""Test that the import step works.""" """Test that the import step works."""
# `configuration.yaml` will always return a timedelta for the `window`
# parameter, FYI:
conf = { conf = {
CONF_LATITUDE: 39.128712, CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612, CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25, CONF_RADIUS: 25,
CONF_UNIT_SYSTEM: "metric",
CONF_WINDOW: timedelta(minutes=10),
} }
flow = config_flow.WWLLNFlowHandler() result = await hass.config_entries.flow.async_init(
flow.hass = hass DOMAIN, context={"source": SOURCE_IMPORT}, data=conf
)
result = await flow.async_step_import(import_config=conf)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "39.128712, -104.9812612" assert result["title"] == "39.128712, -104.9812612"
assert result["data"] == { assert result["data"] == {
CONF_LATITUDE: 39.128712, CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612, CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25, CONF_RADIUS: 25,
CONF_UNIT_SYSTEM: "metric", CONF_WINDOW: 3600.0,
CONF_WINDOW: 600.0,
} }
async def test_step_import_too_small_window(hass):
"""Test that the import step with a too-small window is aborted."""
conf = {
CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25,
CONF_WINDOW: 60,
}
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_IMPORT}, data=conf
)
assert result["type"] == data_entry_flow.RESULT_TYPE_ABORT
assert result["reason"] == "window_too_small"
async def test_step_user(hass): async def test_step_user(hass):
"""Test that the user step works.""" """Test that the user step works."""
conf = {CONF_LATITUDE: 39.128712, CONF_LONGITUDE: -104.9812612, CONF_RADIUS: 25} conf = {CONF_LATITUDE: 39.128712, CONF_LONGITUDE: -104.9812612, CONF_RADIUS: 25}
flow = config_flow.WWLLNFlowHandler() result = await hass.config_entries.flow.async_init(
flow.hass = hass DOMAIN, context={"source": SOURCE_USER}, data=conf
)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "39.128712, -104.9812612"
assert result["data"] == {
CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25,
CONF_WINDOW: 3600.0,
}
async def test_different_unit_system(hass):
"""Test that the config flow picks up the HASS unit system."""
conf = {
CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25,
}
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": SOURCE_USER}, data=conf
)
result = await flow.async_step_user(user_input=conf)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "39.128712, -104.9812612" assert result["title"] == "39.128712, -104.9812612"
assert result["data"] == { assert result["data"] == {
CONF_LATITUDE: 39.128712, CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612, CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25, CONF_RADIUS: 25,
CONF_UNIT_SYSTEM: "metric",
CONF_WINDOW: 3600.0, CONF_WINDOW: 3600.0,
} }
@ -94,20 +125,19 @@ async def test_custom_window(hass):
CONF_LATITUDE: 39.128712, CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612, CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25, CONF_RADIUS: 25,
CONF_WINDOW: timedelta(hours=2), CONF_WINDOW: 7200,
} }
flow = config_flow.WWLLNFlowHandler() result = await hass.config_entries.flow.async_init(
flow.hass = hass DOMAIN, context={"source": SOURCE_USER}, data=conf
)
result = await flow.async_step_user(user_input=conf)
assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY assert result["type"] == data_entry_flow.RESULT_TYPE_CREATE_ENTRY
assert result["title"] == "39.128712, -104.9812612" assert result["title"] == "39.128712, -104.9812612"
assert result["data"] == { assert result["data"] == {
CONF_LATITUDE: 39.128712, CONF_LATITUDE: 39.128712,
CONF_LONGITUDE: -104.9812612, CONF_LONGITUDE: -104.9812612,
CONF_RADIUS: 25, CONF_RADIUS: 25,
CONF_UNIT_SYSTEM: "metric",
CONF_WINDOW: 7200, CONF_WINDOW: 7200,
} }