diff --git a/homeassistant/components/nsw_rural_fire_service_feed/geo_location.py b/homeassistant/components/nsw_rural_fire_service_feed/geo_location.py index 9a9679f9575..a04d2bd69b2 100644 --- a/homeassistant/components/nsw_rural_fire_service_feed/geo_location.py +++ b/homeassistant/components/nsw_rural_fire_service_feed/geo_location.py @@ -3,6 +3,7 @@ from datetime import timedelta import logging from typing import Optional +from aio_geojson_nsw_rfs_incidents import NswRuralFireServiceIncidentsFeedManager import voluptuous as vol from homeassistant.components.geo_location import PLATFORM_SCHEMA, GeolocationEvent @@ -14,11 +15,16 @@ from homeassistant.const import ( CONF_RADIUS, CONF_SCAN_INTERVAL, EVENT_HOMEASSISTANT_START, + EVENT_HOMEASSISTANT_STOP, ) from homeassistant.core import callback -import homeassistant.helpers.config_validation as cv -from homeassistant.helpers.dispatcher import async_dispatcher_connect, dispatcher_send -from homeassistant.helpers.event import track_time_interval +from homeassistant.helpers import ConfigType, aiohttp_client, config_validation as cv +from homeassistant.helpers.dispatcher import ( + async_dispatcher_connect, + async_dispatcher_send, +) +from homeassistant.helpers.event import async_track_time_interval +from homeassistant.helpers.typing import HomeAssistantType _LOGGER = logging.getLogger(__name__) @@ -58,7 +64,9 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend( ) -def setup_platform(hass, config, add_entities, discovery_info=None): +async def async_setup_platform( + hass: HomeAssistantType, config: ConfigType, async_add_entities, discovery_info=None +): """Set up the NSW Rural Fire Service Feed platform.""" scan_interval = config.get(CONF_SCAN_INTERVAL, SCAN_INTERVAL) coordinates = ( @@ -68,30 +76,40 @@ def setup_platform(hass, config, add_entities, discovery_info=None): radius_in_km = config[CONF_RADIUS] categories = config.get(CONF_CATEGORIES) # Initialize the entity manager. - feed = NswRuralFireServiceFeedEntityManager( - hass, add_entities, scan_interval, coordinates, radius_in_km, categories + manager = NswRuralFireServiceFeedEntityManager( + hass, async_add_entities, scan_interval, coordinates, radius_in_km, categories ) - def start_feed_manager(event): + async def start_feed_manager(event): """Start feed manager.""" - feed.startup() + await manager.async_init() - hass.bus.listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager) + async def stop_feed_manager(event): + """Stop feed manager.""" + await manager.async_stop() + + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_START, start_feed_manager) + hass.bus.async_listen_once(EVENT_HOMEASSISTANT_STOP, stop_feed_manager) + hass.async_create_task(manager.async_update()) class NswRuralFireServiceFeedEntityManager: """Feed Entity Manager for NSW Rural Fire Service GeoJSON feed.""" def __init__( - self, hass, add_entities, scan_interval, coordinates, radius_in_km, categories + self, + hass, + async_add_entities, + scan_interval, + coordinates, + radius_in_km, + categories, ): """Initialize the Feed Entity Manager.""" - from geojson_client.nsw_rural_fire_service_feed import ( - NswRuralFireServiceFeedManager, - ) - self._hass = hass - self._feed_manager = NswRuralFireServiceFeedManager( + websession = aiohttp_client.async_get_clientsession(hass) + self._feed_manager = NswRuralFireServiceIncidentsFeedManager( + websession, self._generate_entity, self._update_entity, self._remove_entity, @@ -99,37 +117,52 @@ class NswRuralFireServiceFeedEntityManager: filter_radius=radius_in_km, filter_categories=categories, ) - self._add_entities = add_entities + self._async_add_entities = async_add_entities self._scan_interval = scan_interval + self._track_time_remove_callback = None - def startup(self): - """Start up this manager.""" - self._feed_manager.update() - self._init_regular_updates() + async def async_init(self): + """Schedule initial and regular updates based on configured time interval.""" - def _init_regular_updates(self): - """Schedule regular updates at the specified interval.""" - track_time_interval( - self._hass, lambda now: self._feed_manager.update(), self._scan_interval + async def update(event_time): + """Update.""" + await self.async_update() + + # Trigger updates at regular intervals. + self._track_time_remove_callback = async_track_time_interval( + self._hass, update, self._scan_interval ) + _LOGGER.debug("Feed entity manager initialized") + + async def async_update(self): + """Refresh data.""" + await self._feed_manager.update() + _LOGGER.debug("Feed entity manager updated") + + async def async_stop(self): + """Stop this feed entity manager from refreshing.""" + if self._track_time_remove_callback: + self._track_time_remove_callback() + _LOGGER.debug("Feed entity manager stopped") + def get_entry(self, external_id): """Get feed entry by external id.""" return self._feed_manager.feed_entries.get(external_id) - def _generate_entity(self, external_id): + async def _generate_entity(self, external_id): """Generate new entity.""" new_entity = NswRuralFireServiceLocationEvent(self, external_id) # Add new entities to HA. - self._add_entities([new_entity], True) + self._async_add_entities([new_entity], True) - def _update_entity(self, external_id): + async def _update_entity(self, external_id): """Update entity.""" - dispatcher_send(self._hass, SIGNAL_UPDATE_ENTITY.format(external_id)) + async_dispatcher_send(self._hass, SIGNAL_UPDATE_ENTITY.format(external_id)) - def _remove_entity(self, external_id): + async def _remove_entity(self, external_id): """Remove entity.""" - dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(external_id)) + async_dispatcher_send(self._hass, SIGNAL_DELETE_ENTITY.format(external_id)) class NswRuralFireServiceLocationEvent(GeolocationEvent): @@ -169,11 +202,14 @@ class NswRuralFireServiceLocationEvent(GeolocationEvent): self._update_callback, ) + async def async_will_remove_from_hass(self) -> None: + """Call when entity will be removed from hass.""" + self._remove_signal_delete() + self._remove_signal_update() + @callback def _delete_callback(self): """Remove this entity.""" - self._remove_signal_delete() - self._remove_signal_update() self.hass.async_create_task(self.async_remove()) @callback diff --git a/homeassistant/components/nsw_rural_fire_service_feed/manifest.json b/homeassistant/components/nsw_rural_fire_service_feed/manifest.json index 3d16f0a57e3..7dd7d10d6be 100644 --- a/homeassistant/components/nsw_rural_fire_service_feed/manifest.json +++ b/homeassistant/components/nsw_rural_fire_service_feed/manifest.json @@ -3,7 +3,7 @@ "name": "Nsw rural fire service feed", "documentation": "https://www.home-assistant.io/integrations/nsw_rural_fire_service_feed", "requirements": [ - "geojson_client==0.4" + "aio_geojson_nsw_rfs_incidents==0.1" ], "dependencies": [], "codeowners": [ diff --git a/requirements_all.txt b/requirements_all.txt index 650de67eee5..c774e45fbd7 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -129,6 +129,9 @@ aio_geojson_geonetnz_quakes==0.11 # homeassistant.components.geonetnz_volcano aio_geojson_geonetnz_volcano==0.5 +# homeassistant.components.nsw_rural_fire_service_feed +aio_geojson_nsw_rfs_incidents==0.1 + # homeassistant.components.ambient_station aioambient==0.3.2 @@ -550,7 +553,6 @@ geizhals==0.0.9 geniushub-client==0.6.30 # homeassistant.components.geo_json_events -# homeassistant.components.nsw_rural_fire_service_feed # homeassistant.components.usgs_earthquakes_feed geojson_client==0.4 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 3b6b65d7e4d..3f47fd5317e 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -40,6 +40,9 @@ aio_geojson_geonetnz_quakes==0.11 # homeassistant.components.geonetnz_volcano aio_geojson_geonetnz_volcano==0.5 +# homeassistant.components.nsw_rural_fire_service_feed +aio_geojson_nsw_rfs_incidents==0.1 + # homeassistant.components.ambient_station aioambient==0.3.2 @@ -171,7 +174,6 @@ foobot_async==0.3.1 gTTS-token==1.1.3 # homeassistant.components.geo_json_events -# homeassistant.components.nsw_rural_fire_service_feed # homeassistant.components.usgs_earthquakes_feed geojson_client==0.4 diff --git a/tests/components/nsw_rural_fire_service_feed/test_geo_location.py b/tests/components/nsw_rural_fire_service_feed/test_geo_location.py index f5f88087010..274ef3d3743 100644 --- a/tests/components/nsw_rural_fire_service_feed/test_geo_location.py +++ b/tests/components/nsw_rural_fire_service_feed/test_geo_location.py @@ -1,5 +1,8 @@ -"""The tests for the geojson platform.""" +"""The tests for the NSW Rural Fire Service Feeds platform.""" import datetime +from unittest.mock import ANY + +from aio_geojson_nsw_rfs_incidents import NswRuralFireServiceIncidentsFeed from asynctest.mock import patch, MagicMock, call from homeassistant.components import geo_location @@ -20,6 +23,7 @@ from homeassistant.components.nsw_rural_fire_service_feed.geo_location import ( from homeassistant.const import ( ATTR_ATTRIBUTION, ATTR_FRIENDLY_NAME, + ATTR_ICON, ATTR_LATITUDE, ATTR_LONGITUDE, ATTR_UNIT_OF_MEASUREMENT, @@ -27,7 +31,7 @@ from homeassistant.const import ( CONF_LONGITUDE, CONF_RADIUS, EVENT_HOMEASSISTANT_START, - ATTR_ICON, + EVENT_HOMEASSISTANT_STOP, ) from homeassistant.setup import async_setup_component from tests.common import assert_setup_component, async_fire_time_changed @@ -110,12 +114,12 @@ async def test_setup(hass): mock_entry_3 = _generate_mock_feed_entry("3456", "Title 3", 25.5, (-31.2, 150.2)) mock_entry_4 = _generate_mock_feed_entry("4567", "Title 4", 12.5, (-31.3, 150.3)) - utcnow = dt_util.utcnow() # Patching 'utcnow' to gain more control over the timed update. + utcnow = dt_util.utcnow() with patch("homeassistant.util.dt.utcnow", return_value=utcnow), patch( - "geojson_client.nsw_rural_fire_service_feed." "NswRuralFireServiceFeed" - ) as mock_feed: - mock_feed.return_value.update.return_value = ( + "aio_geojson_client.feed.GeoJsonFeed.update" + ) as mock_feed_update: + mock_feed_update.return_value = ( "OK", [mock_entry_1, mock_entry_2, mock_entry_3], ) @@ -187,7 +191,7 @@ async def test_setup(hass): # Simulate an update - one existing, one new entry, # one outdated entry - mock_feed.return_value.update.return_value = ( + mock_feed_update.return_value = ( "OK", [mock_entry_1, mock_entry_4, mock_entry_3], ) @@ -199,7 +203,7 @@ async def test_setup(hass): # Simulate an update - empty data, but successful update, # so no changes to entities. - mock_feed.return_value.update.return_value = "OK_NO_DATA", None + mock_feed_update.return_value = "OK_NO_DATA", None async_fire_time_changed(hass, utcnow + 2 * SCAN_INTERVAL) await hass.async_block_till_done() @@ -207,13 +211,18 @@ async def test_setup(hass): assert len(all_states) == 3 # Simulate an update - empty data, removes all entities - mock_feed.return_value.update.return_value = "ERROR", None + mock_feed_update.return_value = "ERROR", None async_fire_time_changed(hass, utcnow + 3 * SCAN_INTERVAL) await hass.async_block_till_done() all_states = hass.states.async_all() assert len(all_states) == 0 + # Artificially trigger update. + hass.bus.async_fire(EVENT_HOMEASSISTANT_STOP) + # Collect events. + await hass.async_block_till_done() + async def test_setup_with_custom_location(hass): """Test the setup with a custom location.""" @@ -221,9 +230,12 @@ async def test_setup_with_custom_location(hass): mock_entry_1 = _generate_mock_feed_entry("1234", "Title 1", 20.5, (-31.1, 150.1)) with patch( - "geojson_client.nsw_rural_fire_service_feed." "NswRuralFireServiceFeed" - ) as mock_feed: - mock_feed.return_value.update.return_value = "OK", [mock_entry_1] + "aio_geojson_nsw_rfs_incidents.feed_manager.NswRuralFireServiceIncidentsFeed", + wraps=NswRuralFireServiceIncidentsFeed, + ) as mock_feed_manager, patch( + "aio_geojson_client.feed.GeoJsonFeed.update" + ) as mock_feed_update: + mock_feed_update.return_value = "OK", [mock_entry_1] with assert_setup_component(1, geo_location.DOMAIN): assert await async_setup_component( @@ -238,6 +250,6 @@ async def test_setup_with_custom_location(hass): all_states = hass.states.async_all() assert len(all_states) == 1 - assert mock_feed.call_args == call( - (15.1, 25.2), filter_categories=[], filter_radius=200.0 + assert mock_feed_manager.call_args == call( + ANY, (15.1, 25.2), filter_categories=[], filter_radius=200.0 )