Geo Location platform code clean up (#18717)

* code cleanup to make use of new externalised feed manager

* fixed lint

* revert change, keep asynctest

* using asynctest

* changed unit test from mocking to inspecting dispatcher signals

* code clean-up
This commit is contained in:
Malte Franken 2018-11-27 23:12:29 +11:00 committed by Paulus Schoutsen
parent 013e181497
commit 61e0e11156
4 changed files with 383 additions and 402 deletions

View File

@ -13,7 +13,8 @@ import voluptuous as vol
from homeassistant.components.geo_location import ( from homeassistant.components.geo_location import (
PLATFORM_SCHEMA, GeoLocationEvent) PLATFORM_SCHEMA, GeoLocationEvent)
from homeassistant.const import ( from homeassistant.const import (
CONF_RADIUS, CONF_SCAN_INTERVAL, CONF_URL, EVENT_HOMEASSISTANT_START) CONF_RADIUS, CONF_SCAN_INTERVAL, CONF_URL, EVENT_HOMEASSISTANT_START,
CONF_LATITUDE, CONF_LONGITUDE)
from homeassistant.core import callback from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
@ -38,6 +39,8 @@ SOURCE = 'geo_json_events'
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Required(CONF_URL): cv.string, vol.Required(CONF_URL): cv.string,
vol.Optional(CONF_LATITUDE): cv.latitude,
vol.Optional(CONF_LONGITUDE): cv.longitude,
vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS_IN_KM): vol.Coerce(float), vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS_IN_KM): vol.Coerce(float),
}) })
@ -46,10 +49,12 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
"""Set up the GeoJSON Events platform.""" """Set up the GeoJSON Events platform."""
url = config[CONF_URL] url = config[CONF_URL]
scan_interval = config.get(CONF_SCAN_INTERVAL, SCAN_INTERVAL) scan_interval = config.get(CONF_SCAN_INTERVAL, SCAN_INTERVAL)
coordinates = (config.get(CONF_LATITUDE, hass.config.latitude),
config.get(CONF_LONGITUDE, hass.config.longitude))
radius_in_km = config[CONF_RADIUS] radius_in_km = config[CONF_RADIUS]
# Initialize the entity manager. # Initialize the entity manager.
feed = GeoJsonFeedManager(hass, add_entities, scan_interval, url, feed = GeoJsonFeedEntityManager(
radius_in_km) hass, add_entities, scan_interval, coordinates, url, radius_in_km)
def start_feed_manager(event): def start_feed_manager(event):
"""Start feed manager.""" """Start feed manager."""
@ -58,87 +63,49 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager) hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager)
class GeoJsonFeedManager: class GeoJsonFeedEntityManager:
"""Feed Manager for GeoJSON feeds.""" """Feed Entity Manager for GeoJSON feeds."""
def __init__(self, hass, add_entities, scan_interval, url, radius_in_km): def __init__(self, hass, add_entities, scan_interval, coordinates, url,
radius_in_km):
"""Initialize the GeoJSON Feed Manager.""" """Initialize the GeoJSON Feed Manager."""
from geojson_client.generic_feed import GenericFeed from geojson_client.generic_feed import GenericFeedManager
self._hass = hass self._hass = hass
self._feed = GenericFeed( self._feed_manager = GenericFeedManager(
(hass.config.latitude, hass.config.longitude), self._generate_entity, self._update_entity, self._remove_entity,
filter_radius=radius_in_km, url=url) coordinates, url, filter_radius=radius_in_km)
self._add_entities = add_entities self._add_entities = add_entities
self._scan_interval = scan_interval self._scan_interval = scan_interval
self.feed_entries = {}
self._managed_external_ids = set()
def startup(self): def startup(self):
"""Start up this manager.""" """Start up this manager."""
self._update() self._feed_manager.update()
self._init_regular_updates() self._init_regular_updates()
def _init_regular_updates(self): def _init_regular_updates(self):
"""Schedule regular updates at the specified interval.""" """Schedule regular updates at the specified interval."""
track_time_interval( track_time_interval(
self._hass, lambda now: self._update(), self._scan_interval) self._hass, lambda now: self._feed_manager.update(),
self._scan_interval)
def _update(self): def get_entry(self, external_id):
"""Update the feed and then update connected entities.""" """Get feed entry by external id."""
import geojson_client return self._feed_manager.feed_entries.get(external_id)
status, feed_entries = self._feed.update() def _generate_entity(self, external_id):
if status == geojson_client.UPDATE_OK: """Generate new entity."""
_LOGGER.debug("Data retrieved %s", feed_entries)
# Keep a copy of all feed entries for future lookups by entities.
self.feed_entries = {entry.external_id: entry
for entry in feed_entries}
# For entity management the external ids from the feed are used.
feed_external_ids = set(self.feed_entries)
remove_external_ids = self._managed_external_ids.difference(
feed_external_ids)
self._remove_entities(remove_external_ids)
update_external_ids = self._managed_external_ids.intersection(
feed_external_ids)
self._update_entities(update_external_ids)
create_external_ids = feed_external_ids.difference(
self._managed_external_ids)
self._generate_new_entities(create_external_ids)
elif status == geojson_client.UPDATE_OK_NO_DATA:
_LOGGER.debug(
"Update successful, but no data received from %s", self._feed)
else:
_LOGGER.warning(
"Update not successful, no data received from %s", self._feed)
# Remove all entities.
self._remove_entities(self._managed_external_ids.copy())
def _generate_new_entities(self, external_ids):
"""Generate new entities for events."""
new_entities = []
for external_id in external_ids:
new_entity = GeoJsonLocationEvent(self, external_id) new_entity = GeoJsonLocationEvent(self, external_id)
_LOGGER.debug("New entity added %s", external_id)
new_entities.append(new_entity)
self._managed_external_ids.add(external_id)
# Add new entities to HA. # Add new entities to HA.
self._add_entities(new_entities, True) self._add_entities([new_entity], True)
def _update_entities(self, external_ids): def _update_entity(self, external_id):
"""Update entities.""" """Update entity."""
for external_id in external_ids: dispatcher_send(self._hass, SIGNAL_UPDATE_ENTITY.format(external_id))
_LOGGER.debug("Existing entity found %s", external_id)
dispatcher_send(
self._hass, SIGNAL_UPDATE_ENTITY.format(external_id))
def _remove_entities(self, external_ids): def _remove_entity(self, external_id):
"""Remove entities.""" """Remove entity."""
for external_id in external_ids: dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(external_id))
_LOGGER.debug("Entity not current anymore %s", external_id)
self._managed_external_ids.remove(external_id)
dispatcher_send(
self._hass, SIGNAL_DELETE_ENTITY.format(external_id))
class GeoJsonLocationEvent(GeoLocationEvent): class GeoJsonLocationEvent(GeoLocationEvent):
@ -184,7 +151,7 @@ class GeoJsonLocationEvent(GeoLocationEvent):
async def async_update(self): async def async_update(self):
"""Update this entity from the data held in the feed manager.""" """Update this entity from the data held in the feed manager."""
_LOGGER.debug("Updating %s", self._external_id) _LOGGER.debug("Updating %s", self._external_id)
feed_entry = self._feed_manager.feed_entries.get(self._external_id) feed_entry = self._feed_manager.get_entry(self._external_id)
if feed_entry: if feed_entry:
self._update_from_feed(feed_entry) self._update_from_feed(feed_entry)

View File

@ -14,7 +14,7 @@ from homeassistant.components.geo_location import (
PLATFORM_SCHEMA, GeoLocationEvent) PLATFORM_SCHEMA, GeoLocationEvent)
from homeassistant.const import ( from homeassistant.const import (
ATTR_ATTRIBUTION, ATTR_LOCATION, CONF_RADIUS, CONF_SCAN_INTERVAL, ATTR_ATTRIBUTION, ATTR_LOCATION, CONF_RADIUS, CONF_SCAN_INTERVAL,
EVENT_HOMEASSISTANT_START) EVENT_HOMEASSISTANT_START, CONF_LATITUDE, CONF_LONGITUDE)
from homeassistant.core import callback from homeassistant.core import callback
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
@ -57,18 +57,23 @@ VALID_CATEGORIES = [
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend({
vol.Optional(CONF_CATEGORIES, default=[]): vol.Optional(CONF_CATEGORIES, default=[]):
vol.All(cv.ensure_list, [vol.In(VALID_CATEGORIES)]), vol.All(cv.ensure_list, [vol.In(VALID_CATEGORIES)]),
vol.Optional(CONF_LATITUDE): cv.latitude,
vol.Optional(CONF_LONGITUDE): cv.longitude,
vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS_IN_KM): vol.Coerce(float), vol.Optional(CONF_RADIUS, default=DEFAULT_RADIUS_IN_KM): vol.Coerce(float),
}) })
def setup_platform(hass, config, add_entities, discovery_info=None): def setup_platform(hass, config, add_entities, discovery_info=None):
"""Set up the GeoJSON Events platform.""" """Set up the NSW Rural Fire Service Feed platform."""
scan_interval = config.get(CONF_SCAN_INTERVAL, SCAN_INTERVAL) scan_interval = config.get(CONF_SCAN_INTERVAL, SCAN_INTERVAL)
coordinates = (config.get(CONF_LATITUDE, hass.config.latitude),
config.get(CONF_LONGITUDE, hass.config.longitude))
radius_in_km = config[CONF_RADIUS] radius_in_km = config[CONF_RADIUS]
categories = config.get(CONF_CATEGORIES) categories = config.get(CONF_CATEGORIES)
# Initialize the entity manager. # Initialize the entity manager.
feed = NswRuralFireServiceFeedManager( feed = NswRuralFireServiceFeedEntityManager(
hass, add_entities, scan_interval, radius_in_km, categories) hass, add_entities, scan_interval, coordinates, radius_in_km,
categories)
def start_feed_manager(event): def start_feed_manager(event):
"""Start feed manager.""" """Start feed manager."""
@ -77,93 +82,55 @@ def setup_platform(hass, config, add_entities, discovery_info=None):
hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager) hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager)
class NswRuralFireServiceFeedManager: class NswRuralFireServiceFeedEntityManager:
"""Feed Manager for NSW Rural Fire Service GeoJSON feed.""" """Feed Entity Manager for NSW Rural Fire Service GeoJSON feed."""
def __init__(self, hass, add_entities, scan_interval, radius_in_km, def __init__(self, hass, add_entities, scan_interval, coordinates,
categories): radius_in_km, categories):
"""Initialize the GeoJSON Feed Manager.""" """Initialize the Feed Entity Manager."""
from geojson_client.nsw_rural_fire_service_feed \ from geojson_client.nsw_rural_fire_service_feed \
import NswRuralFireServiceFeed import NswRuralFireServiceFeedManager
self._hass = hass self._hass = hass
self._feed = NswRuralFireServiceFeed( self._feed_manager = NswRuralFireServiceFeedManager(
(hass.config.latitude, hass.config.longitude), self._generate_entity, self._update_entity, self._remove_entity,
filter_radius=radius_in_km, filter_categories=categories) coordinates, filter_radius=radius_in_km,
filter_categories=categories)
self._add_entities = add_entities self._add_entities = add_entities
self._scan_interval = scan_interval self._scan_interval = scan_interval
self.feed_entries = {}
self._managed_external_ids = set()
def startup(self): def startup(self):
"""Start up this manager.""" """Start up this manager."""
self._update() self._feed_manager.update()
self._init_regular_updates() self._init_regular_updates()
def _init_regular_updates(self): def _init_regular_updates(self):
"""Schedule regular updates at the specified interval.""" """Schedule regular updates at the specified interval."""
track_time_interval( track_time_interval(
self._hass, lambda now: self._update(), self._scan_interval) self._hass, lambda now: self._feed_manager.update(),
self._scan_interval)
def _update(self): def get_entry(self, external_id):
"""Update the feed and then update connected entities.""" """Get feed entry by external id."""
import geojson_client return self._feed_manager.feed_entries.get(external_id)
status, feed_entries = self._feed.update() def _generate_entity(self, external_id):
if status == geojson_client.UPDATE_OK: """Generate new entity."""
_LOGGER.debug("Data retrieved %s", feed_entries)
# Keep a copy of all feed entries for future lookups by entities.
self.feed_entries = {entry.external_id: entry
for entry in feed_entries}
# For entity management the external ids from the feed are used.
feed_external_ids = set(self.feed_entries)
remove_external_ids = self._managed_external_ids.difference(
feed_external_ids)
self._remove_entities(remove_external_ids)
update_external_ids = self._managed_external_ids.intersection(
feed_external_ids)
self._update_entities(update_external_ids)
create_external_ids = feed_external_ids.difference(
self._managed_external_ids)
self._generate_new_entities(create_external_ids)
elif status == geojson_client.UPDATE_OK_NO_DATA:
_LOGGER.debug(
"Update successful, but no data received from %s", self._feed)
else:
_LOGGER.warning(
"Update not successful, no data received from %s", self._feed)
# Remove all entities.
self._remove_entities(self._managed_external_ids.copy())
def _generate_new_entities(self, external_ids):
"""Generate new entities for events."""
new_entities = []
for external_id in external_ids:
new_entity = NswRuralFireServiceLocationEvent(self, external_id) new_entity = NswRuralFireServiceLocationEvent(self, external_id)
_LOGGER.debug("New entity added %s", external_id)
new_entities.append(new_entity)
self._managed_external_ids.add(external_id)
# Add new entities to HA. # Add new entities to HA.
self._add_entities(new_entities, True) self._add_entities([new_entity], True)
def _update_entities(self, external_ids): def _update_entity(self, external_id):
"""Update entities.""" """Update entity."""
for external_id in external_ids: dispatcher_send(self._hass, SIGNAL_UPDATE_ENTITY.format(external_id))
_LOGGER.debug("Existing entity found %s", external_id)
dispatcher_send(
self._hass, SIGNAL_UPDATE_ENTITY.format(external_id))
def _remove_entities(self, external_ids): def _remove_entity(self, external_id):
"""Remove entities.""" """Remove entity."""
for external_id in external_ids: dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(external_id))
_LOGGER.debug("Entity not current anymore %s", external_id)
self._managed_external_ids.remove(external_id)
dispatcher_send(
self._hass, SIGNAL_DELETE_ENTITY.format(external_id))
class NswRuralFireServiceLocationEvent(GeoLocationEvent): class NswRuralFireServiceLocationEvent(GeoLocationEvent):
"""This represents an external event with GeoJSON data.""" """This represents an external event with NSW Rural Fire Service data."""
def __init__(self, feed_manager, external_id): def __init__(self, feed_manager, external_id):
"""Initialize entity with data from feed entry.""" """Initialize entity with data from feed entry."""
@ -209,13 +176,13 @@ class NswRuralFireServiceLocationEvent(GeoLocationEvent):
@property @property
def should_poll(self): def should_poll(self):
"""No polling needed for GeoJSON location events.""" """No polling needed for NSW Rural Fire Service location events."""
return False return False
async def async_update(self): async def async_update(self):
"""Update this entity from the data held in the feed manager.""" """Update this entity from the data held in the feed manager."""
_LOGGER.debug("Updating %s", self._external_id) _LOGGER.debug("Updating %s", self._external_id)
feed_entry = self._feed_manager.feed_entries.get(self._external_id) feed_entry = self._feed_manager.get_entry(self._external_id)
if feed_entry: if feed_entry:
self._update_from_feed(feed_entry) self._update_from_feed(feed_entry)

View File

@ -1,19 +1,16 @@
"""The tests for the geojson platform.""" """The tests for the geojson platform."""
import unittest from asynctest.mock import patch, MagicMock, call
from unittest import mock
from unittest.mock import patch, MagicMock
import homeassistant
from homeassistant.components import geo_location from homeassistant.components import geo_location
from homeassistant.components.geo_location import ATTR_SOURCE from homeassistant.components.geo_location import ATTR_SOURCE
from homeassistant.components.geo_location.geo_json_events import \ from homeassistant.components.geo_location.geo_json_events import \
SCAN_INTERVAL, ATTR_EXTERNAL_ID SCAN_INTERVAL, ATTR_EXTERNAL_ID, SIGNAL_DELETE_ENTITY, SIGNAL_UPDATE_ENTITY
from homeassistant.const import CONF_URL, EVENT_HOMEASSISTANT_START, \ from homeassistant.const import CONF_URL, EVENT_HOMEASSISTANT_START, \
CONF_RADIUS, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_FRIENDLY_NAME, \ CONF_RADIUS, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_FRIENDLY_NAME, \
ATTR_UNIT_OF_MEASUREMENT ATTR_UNIT_OF_MEASUREMENT, CONF_LATITUDE, CONF_LONGITUDE
from homeassistant.setup import setup_component from homeassistant.helpers.dispatcher import DATA_DISPATCHER
from tests.common import get_test_home_assistant, assert_setup_component, \ from homeassistant.setup import async_setup_component
fire_time_changed from tests.common import assert_setup_component, async_fire_time_changed
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
URL = 'http://geo.json.local/geo_json_events.json' URL = 'http://geo.json.local/geo_json_events.json'
@ -27,19 +24,19 @@ CONFIG = {
] ]
} }
CONFIG_WITH_CUSTOM_LOCATION = {
geo_location.DOMAIN: [
{
'platform': 'geo_json_events',
CONF_URL: URL,
CONF_RADIUS: 200,
CONF_LATITUDE: 15.1,
CONF_LONGITUDE: 25.2
}
]
}
class TestGeoJsonPlatform(unittest.TestCase):
"""Test the geojson platform."""
def setUp(self):
"""Initialize values for this testcase class."""
self.hass = get_test_home_assistant()
def tearDown(self):
"""Stop everything that was started."""
self.hass.stop()
@staticmethod
def _generate_mock_feed_entry(external_id, title, distance_to_home, def _generate_mock_feed_entry(external_id, title, distance_to_home,
coordinates): coordinates):
"""Construct a mock feed entry for testing purposes.""" """Construct a mock feed entry for testing purposes."""
@ -50,36 +47,38 @@ class TestGeoJsonPlatform(unittest.TestCase):
feed_entry.coordinates = coordinates feed_entry.coordinates = coordinates
return feed_entry return feed_entry
@mock.patch('geojson_client.generic_feed.GenericFeed')
def test_setup(self, mock_feed): async def test_setup(hass):
"""Test the general setup of the platform.""" """Test the general setup of the platform."""
# Set up some mock feed entries for this test. # Set up some mock feed entries for this test.
mock_entry_1 = self._generate_mock_feed_entry('1234', 'Title 1', 15.5, mock_entry_1 = _generate_mock_feed_entry(
(-31.0, 150.0)) '1234', 'Title 1', 15.5, (-31.0, 150.0))
mock_entry_2 = self._generate_mock_feed_entry('2345', 'Title 2', 20.5, mock_entry_2 = _generate_mock_feed_entry(
(-31.1, 150.1)) '2345', 'Title 2', 20.5, (-31.1, 150.1))
mock_entry_3 = self._generate_mock_feed_entry('3456', 'Title 3', 25.5, mock_entry_3 = _generate_mock_feed_entry(
(-31.2, 150.2)) '3456', 'Title 3', 25.5, (-31.2, 150.2))
mock_entry_4 = self._generate_mock_feed_entry('4567', 'Title 4', 12.5, mock_entry_4 = _generate_mock_feed_entry(
(-31.3, 150.3)) '4567', 'Title 4', 12.5, (-31.3, 150.3))
# Patching 'utcnow' to gain more control over the timed update.
utcnow = dt_util.utcnow()
with patch('homeassistant.util.dt.utcnow', return_value=utcnow), \
patch('geojson_client.generic_feed.GenericFeed') as mock_feed:
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1, mock_feed.return_value.update.return_value = 'OK', [mock_entry_1,
mock_entry_2, mock_entry_2,
mock_entry_3] mock_entry_3]
utcnow = dt_util.utcnow()
# Patching 'utcnow' to gain more control over the timed update.
with patch('homeassistant.util.dt.utcnow', return_value=utcnow):
with assert_setup_component(1, geo_location.DOMAIN): with assert_setup_component(1, geo_location.DOMAIN):
assert setup_component(self.hass, geo_location.DOMAIN, CONFIG) assert await async_setup_component(
hass, geo_location.DOMAIN, CONFIG)
# Artificially trigger update. # Artificially trigger update.
self.hass.bus.fire(EVENT_HOMEASSISTANT_START) hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
# Collect events. # Collect events.
self.hass.block_till_done() await hass.async_block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 3 assert len(all_states) == 3
state = self.hass.states.get("geo_location.title_1") state = hass.states.get("geo_location.title_1")
assert state is not None assert state is not None
assert state.name == "Title 1" assert state.name == "Title 1"
assert state.attributes == { assert state.attributes == {
@ -89,7 +88,7 @@ class TestGeoJsonPlatform(unittest.TestCase):
ATTR_SOURCE: 'geo_json_events'} ATTR_SOURCE: 'geo_json_events'}
assert round(abs(float(state.state)-15.5), 7) == 0 assert round(abs(float(state.state)-15.5), 7) == 0
state = self.hass.states.get("geo_location.title_2") state = hass.states.get("geo_location.title_2")
assert state is not None assert state is not None
assert state.name == "Title 2" assert state.name == "Title 2"
assert state.attributes == { assert state.attributes == {
@ -99,7 +98,7 @@ class TestGeoJsonPlatform(unittest.TestCase):
ATTR_SOURCE: 'geo_json_events'} ATTR_SOURCE: 'geo_json_events'}
assert round(abs(float(state.state)-20.5), 7) == 0 assert round(abs(float(state.state)-20.5), 7) == 0
state = self.hass.states.get("geo_location.title_3") state = hass.states.get("geo_location.title_3")
assert state is not None assert state is not None
assert state.name == "Title 3" assert state.name == "Title 3"
assert state.attributes == { assert state.attributes == {
@ -113,34 +112,56 @@ class TestGeoJsonPlatform(unittest.TestCase):
# one outdated entry # one outdated entry
mock_feed.return_value.update.return_value = 'OK', [ mock_feed.return_value.update.return_value = 'OK', [
mock_entry_1, mock_entry_4, mock_entry_3] mock_entry_1, mock_entry_4, mock_entry_3]
fire_time_changed(self.hass, utcnow + SCAN_INTERVAL) async_fire_time_changed(hass, utcnow + SCAN_INTERVAL)
self.hass.block_till_done() await hass.async_block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 3 assert len(all_states) == 3
# Simulate an update - empty data, but successful update, # Simulate an update - empty data, but successful update,
# so no changes to entities. # so no changes to entities.
mock_feed.return_value.update.return_value = 'OK_NO_DATA', None mock_feed.return_value.update.return_value = 'OK_NO_DATA', None
# mock_restdata.return_value.data = None async_fire_time_changed(hass, utcnow + 2 * SCAN_INTERVAL)
fire_time_changed(self.hass, utcnow + await hass.async_block_till_done()
2 * SCAN_INTERVAL)
self.hass.block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 3 assert len(all_states) == 3
# Simulate an update - empty data, removes all entities # Simulate an update - empty data, removes all entities
mock_feed.return_value.update.return_value = 'ERROR', None mock_feed.return_value.update.return_value = 'ERROR', None
fire_time_changed(self.hass, utcnow + async_fire_time_changed(hass, utcnow + 3 * SCAN_INTERVAL)
2 * SCAN_INTERVAL) await hass.async_block_till_done()
self.hass.block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 0 assert len(all_states) == 0
@mock.patch('geojson_client.generic_feed.GenericFeed')
def test_setup_race_condition(self, mock_feed): async def test_setup_with_custom_location(hass):
"""Test the setup with a custom location."""
# Set up some mock feed entries for this test.
mock_entry_1 = _generate_mock_feed_entry(
'1234', 'Title 1', 2000.5, (-31.1, 150.1))
with patch('geojson_client.generic_feed.GenericFeed') as mock_feed:
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1]
with assert_setup_component(1, geo_location.DOMAIN):
assert await async_setup_component(
hass, geo_location.DOMAIN, CONFIG_WITH_CUSTOM_LOCATION)
# Artificially trigger update.
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
# Collect events.
await hass.async_block_till_done()
all_states = hass.states.async_all()
assert len(all_states) == 1
assert mock_feed.call_args == call(
(15.1, 25.2), URL, filter_radius=200.0)
async def test_setup_race_condition(hass):
"""Test a particular race condition experienced.""" """Test a particular race condition experienced."""
# 1. Feed returns 1 entry -> Feed manager creates 1 entity. # 1. Feed returns 1 entry -> Feed manager creates 1 entity.
# 2. Feed returns error -> Feed manager removes 1 entity. # 2. Feed returns error -> Feed manager removes 1 entity.
@ -153,74 +174,68 @@ class TestGeoJsonPlatform(unittest.TestCase):
# the second attempt fails of course. # the second attempt fails of course.
# Set up some mock feed entries for this test. # Set up some mock feed entries for this test.
mock_entry_1 = self._generate_mock_feed_entry('1234', 'Title 1', 15.5, mock_entry_1 = _generate_mock_feed_entry(
(-31.0, 150.0)) '1234', 'Title 1', 15.5, (-31.0, 150.0))
delete_signal = SIGNAL_DELETE_ENTITY.format('1234')
update_signal = SIGNAL_UPDATE_ENTITY.format('1234')
# Patching 'utcnow' to gain more control over the timed update.
utcnow = dt_util.utcnow()
with patch('homeassistant.util.dt.utcnow', return_value=utcnow), \
patch('geojson_client.generic_feed.GenericFeed') as mock_feed:
with assert_setup_component(1, geo_location.DOMAIN):
assert await async_setup_component(
hass, geo_location.DOMAIN, CONFIG)
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1] mock_feed.return_value.update.return_value = 'OK', [mock_entry_1]
utcnow = dt_util.utcnow()
# Patching 'utcnow' to gain more control over the timed update.
with patch('homeassistant.util.dt.utcnow', return_value=utcnow):
with assert_setup_component(1, geo_location.DOMAIN):
assert setup_component(self.hass, geo_location.DOMAIN, CONFIG)
# This gives us the ability to assert the '_delete_callback'
# has been called while still executing it.
original_delete_callback = homeassistant.components\
.geo_location.geo_json_events.GeoJsonLocationEvent\
._delete_callback
def mock_delete_callback(entity):
original_delete_callback(entity)
with patch('homeassistant.components.geo_location'
'.geo_json_events.GeoJsonLocationEvent'
'._delete_callback',
side_effect=mock_delete_callback,
autospec=True) as mocked_delete_callback:
# Artificially trigger update. # Artificially trigger update.
self.hass.bus.fire(EVENT_HOMEASSISTANT_START) hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
# Collect events. # Collect events.
self.hass.block_till_done() await hass.async_block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 1 assert len(all_states) == 1
assert len(hass.data[DATA_DISPATCHER][delete_signal]) == 1
assert len(hass.data[DATA_DISPATCHER][update_signal]) == 1
# Simulate an update - empty data, removes all entities # Simulate an update - empty data, removes all entities
mock_feed.return_value.update.return_value = 'ERROR', None mock_feed.return_value.update.return_value = 'ERROR', None
fire_time_changed(self.hass, utcnow + SCAN_INTERVAL) async_fire_time_changed(hass, utcnow + SCAN_INTERVAL)
self.hass.block_till_done() await hass.async_block_till_done()
assert mocked_delete_callback.call_count == 1 all_states = hass.states.async_all()
all_states = self.hass.states.all()
assert len(all_states) == 0 assert len(all_states) == 0
assert len(hass.data[DATA_DISPATCHER][delete_signal]) == 0
assert len(hass.data[DATA_DISPATCHER][update_signal]) == 0
# Simulate an update - 1 entry # Simulate an update - 1 entry
mock_feed.return_value.update.return_value = 'OK', [ mock_feed.return_value.update.return_value = 'OK', [mock_entry_1]
mock_entry_1] async_fire_time_changed(hass, utcnow + 2 * SCAN_INTERVAL)
fire_time_changed(self.hass, utcnow + 2 * SCAN_INTERVAL) await hass.async_block_till_done()
self.hass.block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 1 assert len(all_states) == 1
assert len(hass.data[DATA_DISPATCHER][delete_signal]) == 1
assert len(hass.data[DATA_DISPATCHER][update_signal]) == 1
# Simulate an update - 1 entry # Simulate an update - 1 entry
mock_feed.return_value.update.return_value = 'OK', [ mock_feed.return_value.update.return_value = 'OK', [mock_entry_1]
mock_entry_1] async_fire_time_changed(hass, utcnow + 3 * SCAN_INTERVAL)
fire_time_changed(self.hass, utcnow + 3 * SCAN_INTERVAL) await hass.async_block_till_done()
self.hass.block_till_done()
all_states = self.hass.states.all() all_states = hass.states.async_all()
assert len(all_states) == 1 assert len(all_states) == 1
assert len(hass.data[DATA_DISPATCHER][delete_signal]) == 1
# Reset mocked method for the next test. assert len(hass.data[DATA_DISPATCHER][update_signal]) == 1
mocked_delete_callback.reset_mock()
# Simulate an update - empty data, removes all entities # Simulate an update - empty data, removes all entities
mock_feed.return_value.update.return_value = 'ERROR', None mock_feed.return_value.update.return_value = 'ERROR', None
fire_time_changed(self.hass, utcnow + 4 * SCAN_INTERVAL) async_fire_time_changed(hass, utcnow + 4 * SCAN_INTERVAL)
self.hass.block_till_done() await hass.async_block_till_done()
assert mocked_delete_callback.call_count == 1 all_states = hass.states.async_all()
all_states = self.hass.states.all()
assert len(all_states) == 0 assert len(all_states) == 0
# Ensure that delete and update signal targets are now empty.
assert len(hass.data[DATA_DISPATCHER][delete_signal]) == 0
assert len(hass.data[DATA_DISPATCHER][update_signal]) == 0

View File

@ -1,6 +1,6 @@
"""The tests for the geojson platform.""" """The tests for the geojson platform."""
import datetime import datetime
from asynctest.mock import patch, MagicMock from asynctest.mock import patch, MagicMock, call
from homeassistant.components import geo_location from homeassistant.components import geo_location
from homeassistant.components.geo_location import ATTR_SOURCE from homeassistant.components.geo_location import ATTR_SOURCE
@ -8,24 +8,33 @@ from homeassistant.components.geo_location.nsw_rural_fire_service_feed import \
ATTR_EXTERNAL_ID, SCAN_INTERVAL, ATTR_CATEGORY, ATTR_FIRE, ATTR_LOCATION, \ ATTR_EXTERNAL_ID, SCAN_INTERVAL, ATTR_CATEGORY, ATTR_FIRE, ATTR_LOCATION, \
ATTR_COUNCIL_AREA, ATTR_STATUS, ATTR_TYPE, ATTR_SIZE, \ ATTR_COUNCIL_AREA, ATTR_STATUS, ATTR_TYPE, ATTR_SIZE, \
ATTR_RESPONSIBLE_AGENCY, ATTR_PUBLICATION_DATE ATTR_RESPONSIBLE_AGENCY, ATTR_PUBLICATION_DATE
from homeassistant.const import CONF_URL, EVENT_HOMEASSISTANT_START, \ from homeassistant.const import ATTR_ATTRIBUTION, ATTR_FRIENDLY_NAME, \
CONF_RADIUS, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_FRIENDLY_NAME, \ ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_UNIT_OF_MEASUREMENT, CONF_LATITUDE, \
ATTR_UNIT_OF_MEASUREMENT, ATTR_ATTRIBUTION CONF_LONGITUDE, CONF_RADIUS, EVENT_HOMEASSISTANT_START
from homeassistant.setup import async_setup_component from homeassistant.setup import async_setup_component
from tests.common import assert_setup_component, async_fire_time_changed from tests.common import assert_setup_component, async_fire_time_changed
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
URL = 'http://geo.json.local/geo_json_events.json'
CONFIG = { CONFIG = {
geo_location.DOMAIN: [ geo_location.DOMAIN: [
{ {
'platform': 'nsw_rural_fire_service_feed', 'platform': 'nsw_rural_fire_service_feed',
CONF_URL: URL,
CONF_RADIUS: 200 CONF_RADIUS: 200
} }
] ]
} }
CONFIG_WITH_CUSTOM_LOCATION = {
geo_location.DOMAIN: [
{
'platform': 'nsw_rural_fire_service_feed',
CONF_RADIUS: 200,
CONF_LATITUDE: 15.1,
CONF_LONGITUDE: 25.2
}
]
}
def _generate_mock_feed_entry(external_id, title, distance_to_home, def _generate_mock_feed_entry(external_id, title, distance_to_home,
coordinates, category=None, location=None, coordinates, category=None, location=None,
@ -55,8 +64,6 @@ def _generate_mock_feed_entry(external_id, title, distance_to_home,
async def test_setup(hass): async def test_setup(hass):
"""Test the general setup of the platform.""" """Test the general setup of the platform."""
# Set up some mock feed entries for this test. # Set up some mock feed entries for this test.
with patch('geojson_client.nsw_rural_fire_service_feed.'
'NswRuralFireServiceFeed') as mock_feed:
mock_entry_1 = _generate_mock_feed_entry( mock_entry_1 = _generate_mock_feed_entry(
'1234', 'Title 1', 15.5, (-31.0, 150.0), category='Category 1', '1234', 'Title 1', 15.5, (-31.0, 150.0), category='Category 1',
location='Location 1', attribution='Attribution 1', location='Location 1', attribution='Attribution 1',
@ -71,13 +78,15 @@ async def test_setup(hass):
(-31.2, 150.2)) (-31.2, 150.2))
mock_entry_4 = _generate_mock_feed_entry('4567', 'Title 4', 12.5, mock_entry_4 = _generate_mock_feed_entry('4567', 'Title 4', 12.5,
(-31.3, 150.3)) (-31.3, 150.3))
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1,
mock_entry_2,
mock_entry_3]
utcnow = dt_util.utcnow() utcnow = dt_util.utcnow()
# Patching 'utcnow' to gain more control over the timed update. # Patching 'utcnow' to gain more control over the timed update.
with patch('homeassistant.util.dt.utcnow', return_value=utcnow): with patch('homeassistant.util.dt.utcnow', return_value=utcnow), \
patch('geojson_client.nsw_rural_fire_service_feed.'
'NswRuralFireServiceFeed') as mock_feed:
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1,
mock_entry_2,
mock_entry_3]
with assert_setup_component(1, geo_location.DOMAIN): with assert_setup_component(1, geo_location.DOMAIN):
assert await async_setup_component( assert await async_setup_component(
hass, geo_location.DOMAIN, CONFIG) hass, geo_location.DOMAIN, CONFIG)
@ -143,9 +152,7 @@ async def test_setup(hass):
# Simulate an update - empty data, but successful update, # Simulate an update - empty data, but successful update,
# so no changes to entities. # so no changes to entities.
mock_feed.return_value.update.return_value = 'OK_NO_DATA', None mock_feed.return_value.update.return_value = 'OK_NO_DATA', None
# mock_restdata.return_value.data = None async_fire_time_changed(hass, utcnow + 2 * SCAN_INTERVAL)
async_fire_time_changed(hass, utcnow +
2 * SCAN_INTERVAL)
await hass.async_block_till_done() await hass.async_block_till_done()
all_states = hass.states.async_all() all_states = hass.states.async_all()
@ -153,9 +160,34 @@ async def test_setup(hass):
# Simulate an update - empty data, removes all entities # Simulate an update - empty data, removes all entities
mock_feed.return_value.update.return_value = 'ERROR', None mock_feed.return_value.update.return_value = 'ERROR', None
async_fire_time_changed(hass, utcnow + async_fire_time_changed(hass, utcnow + 3 * SCAN_INTERVAL)
2 * SCAN_INTERVAL)
await hass.async_block_till_done() await hass.async_block_till_done()
all_states = hass.states.async_all() all_states = hass.states.async_all()
assert len(all_states) == 0 assert len(all_states) == 0
async def test_setup_with_custom_location(hass):
"""Test the setup with a custom location."""
# Set up some mock feed entries for this test.
mock_entry_1 = _generate_mock_feed_entry(
'1234', 'Title 1', 20.5, (-31.1, 150.1))
with patch('geojson_client.nsw_rural_fire_service_feed.'
'NswRuralFireServiceFeed') as mock_feed:
mock_feed.return_value.update.return_value = 'OK', [mock_entry_1]
with assert_setup_component(1, geo_location.DOMAIN):
assert await async_setup_component(
hass, geo_location.DOMAIN, CONFIG_WITH_CUSTOM_LOCATION)
# Artificially trigger update.
hass.bus.async_fire(EVENT_HOMEASSISTANT_START)
# Collect events.
await hass.async_block_till_done()
all_states = hass.states.async_all()
assert len(all_states) == 1
assert mock_feed.call_args == call(
(15.1, 25.2), filter_categories=[], filter_radius=200.0)