From e652d37f2944528a289eaadd9cbe18fd7019cd06 Mon Sep 17 00:00:00 2001 From: Ian Date: Mon, 2 Oct 2023 01:56:10 -0700 Subject: [PATCH] Use data update coordinator in NextBus to reduce api calls (#100602) --- homeassistant/components/nextbus/__init__.py | 28 ++++- .../components/nextbus/coordinator.py | 78 +++++++++++++ .../components/nextbus/manifest.json | 2 +- homeassistant/components/nextbus/sensor.py | 104 ++++++++++-------- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- tests/components/nextbus/test_sensor.py | 38 ++++++- 7 files changed, 200 insertions(+), 54 deletions(-) create mode 100644 homeassistant/components/nextbus/coordinator.py diff --git a/homeassistant/components/nextbus/__init__.py b/homeassistant/components/nextbus/__init__.py index b582f82b929..e1f4dcc2840 100644 --- a/homeassistant/components/nextbus/__init__.py +++ b/homeassistant/components/nextbus/__init__.py @@ -4,15 +4,41 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant +from .const import CONF_AGENCY, CONF_ROUTE, CONF_STOP, DOMAIN +from .coordinator import NextBusDataUpdateCoordinator + PLATFORMS = [Platform.SENSOR] async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Set up platforms for NextBus.""" + entry_agency = entry.data[CONF_AGENCY] + + coordinator: NextBusDataUpdateCoordinator = hass.data.setdefault(DOMAIN, {}).get( + entry_agency + ) + if coordinator is None: + coordinator = NextBusDataUpdateCoordinator(hass, entry_agency) + hass.data[DOMAIN][entry_agency] = coordinator + + coordinator.add_stop_route(entry.data[CONF_STOP], entry.data[CONF_ROUTE]) + + await coordinator.async_config_entry_first_refresh() + await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) + return True async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Unload a config entry.""" - return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) + if await hass.config_entries.async_unload_platforms(entry, PLATFORMS): + entry_agency = entry.data.get(CONF_AGENCY) + coordinator: NextBusDataUpdateCoordinator = hass.data[DOMAIN][entry_agency] + coordinator.remove_stop_route(entry.data[CONF_STOP], entry.data[CONF_ROUTE]) + if not coordinator.has_routes(): + hass.data[DOMAIN].pop(entry_agency) + + return True + + return False diff --git a/homeassistant/components/nextbus/coordinator.py b/homeassistant/components/nextbus/coordinator.py new file mode 100644 index 00000000000..f130e40ef05 --- /dev/null +++ b/homeassistant/components/nextbus/coordinator.py @@ -0,0 +1,78 @@ +"""NextBus data update coordinator.""" +from datetime import timedelta +import logging +from typing import Any, cast + +from py_nextbus import NextBusClient +from py_nextbus.client import NextBusFormatError, NextBusHTTPError, RouteStop + +from homeassistant.core import HomeAssistant +from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed + +from .const import DOMAIN +from .util import listify + +_LOGGER = logging.getLogger(__name__) + + +class NextBusDataUpdateCoordinator(DataUpdateCoordinator): + """Class to manage fetching NextBus data.""" + + def __init__(self, hass: HomeAssistant, agency: str) -> None: + """Initialize a global coordinator for fetching data for a given agency.""" + super().__init__( + hass, + _LOGGER, + name=DOMAIN, + update_interval=timedelta(seconds=30), + ) + self.client = NextBusClient(output_format="json", agency=agency) + self._agency = agency + self._stop_routes: set[RouteStop] = set() + self._predictions: dict[RouteStop, dict[str, Any]] = {} + + def add_stop_route(self, stop_tag: str, route_tag: str) -> None: + """Tell coordinator to start tracking a given stop and route.""" + self._stop_routes.add(RouteStop(route_tag, stop_tag)) + + def remove_stop_route(self, stop_tag: str, route_tag: str) -> None: + """Tell coordinator to stop tracking a given stop and route.""" + self._stop_routes.remove(RouteStop(route_tag, stop_tag)) + + def get_prediction_data( + self, stop_tag: str, route_tag: str + ) -> dict[str, Any] | None: + """Get prediction result for a given stop and route.""" + return self._predictions.get(RouteStop(route_tag, stop_tag)) + + def _calc_predictions(self, data: dict[str, Any]) -> None: + self._predictions = { + RouteStop(prediction["routeTag"], prediction["stopTag"]): prediction + for prediction in listify(data.get("predictions", [])) + } + + def get_attribution(self) -> str | None: + """Get attribution from api results.""" + return self.data.get("copyright") + + def has_routes(self) -> bool: + """Check if this coordinator is tracking any routes.""" + return len(self._stop_routes) > 0 + + async def _async_update_data(self) -> dict[str, Any]: + """Fetch data from NextBus.""" + self.logger.debug("Updating data from API. Routes: %s", str(self._stop_routes)) + + def _update_data() -> dict: + """Fetch data from NextBus.""" + self.logger.debug("Updating data from API (executor)") + try: + data = self.client.get_predictions_for_multi_stops(self._stop_routes) + # Casting here because we expect dict and not a str due to the input format selected being JSON + data = cast(dict[str, Any], data) + self._calc_predictions(data) + return data + except (NextBusHTTPError, NextBusFormatError) as ex: + raise UpdateFailed("Failed updating nextbus data", ex) from ex + + return await self.hass.async_add_executor_job(_update_data) diff --git a/homeassistant/components/nextbus/manifest.json b/homeassistant/components/nextbus/manifest.json index 15eb9b4e245..9d1490a4ae6 100644 --- a/homeassistant/components/nextbus/manifest.json +++ b/homeassistant/components/nextbus/manifest.json @@ -6,5 +6,5 @@ "documentation": "https://www.home-assistant.io/integrations/nextbus", "iot_class": "cloud_polling", "loggers": ["py_nextbus"], - "requirements": ["py-nextbusnext==0.1.5"] + "requirements": ["py-nextbusnext==1.0.0"] } diff --git a/homeassistant/components/nextbus/sensor.py b/homeassistant/components/nextbus/sensor.py index 1582ec25ffe..6ef647f98ad 100644 --- a/homeassistant/components/nextbus/sensor.py +++ b/homeassistant/components/nextbus/sensor.py @@ -3,8 +3,8 @@ from __future__ import annotations from itertools import chain import logging +from typing import cast -from py_nextbus import NextBusClient import voluptuous as vol from homeassistant.components.sensor import ( @@ -14,14 +14,16 @@ from homeassistant.components.sensor import ( ) from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.const import CONF_NAME -from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant +from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant, callback import homeassistant.helpers.config_validation as cv from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.issue_registry import IssueSeverity, async_create_issue from homeassistant.helpers.typing import ConfigType, DiscoveryInfoType +from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.util.dt import utc_from_timestamp from .const import CONF_AGENCY, CONF_ROUTE, CONF_STOP, DOMAIN +from .coordinator import NextBusDataUpdateCoordinator from .util import listify, maybe_first _LOGGER = logging.getLogger(__name__) @@ -70,23 +72,28 @@ async def async_setup_entry( async_add_entities: AddEntitiesCallback, ) -> None: """Load values from configuration and initialize the platform.""" - client = NextBusClient(output_format="json") - _LOGGER.debug(config.data) + entry_agency = config.data[CONF_AGENCY] - sensor = NextBusDepartureSensor( - client, - config.unique_id, - config.data[CONF_AGENCY], - config.data[CONF_ROUTE], - config.data[CONF_STOP], - config.data.get(CONF_NAME) or config.title, + coordinator: NextBusDataUpdateCoordinator = hass.data[DOMAIN].get(entry_agency) + + async_add_entities( + ( + NextBusDepartureSensor( + coordinator, + cast(str, config.unique_id), + config.data[CONF_AGENCY], + config.data[CONF_ROUTE], + config.data[CONF_STOP], + config.data.get(CONF_NAME) or config.title, + ), + ), ) - async_add_entities((sensor,), True) - -class NextBusDepartureSensor(SensorEntity): +class NextBusDepartureSensor( + CoordinatorEntity[NextBusDataUpdateCoordinator], SensorEntity +): """Sensor class that displays upcoming NextBus times. To function, this requires knowing the agency tag as well as the tags for @@ -100,49 +107,57 @@ class NextBusDepartureSensor(SensorEntity): _attr_device_class = SensorDeviceClass.TIMESTAMP _attr_icon = "mdi:bus" - def __init__(self, client, unique_id, agency, route, stop, name): + def __init__( + self, + coordinator: NextBusDataUpdateCoordinator, + unique_id: str, + agency: str, + route: str, + stop: str, + name: str, + ) -> None: """Initialize sensor with all required config.""" + super().__init__(coordinator) self.agency = agency self.route = route self.stop = stop - self._attr_extra_state_attributes = {} + self._attr_extra_state_attributes: dict[str, str] = {} self._attr_unique_id = unique_id self._attr_name = name - self._client = client - def _log_debug(self, message, *args): """Log debug message with prefix.""" _LOGGER.debug(":".join((self.agency, self.route, self.stop, message)), *args) - def update(self) -> None: + def _log_err(self, message, *args): + """Log error message with prefix.""" + _LOGGER.error(":".join((self.agency, self.route, self.stop, message)), *args) + + async def async_added_to_hass(self) -> None: + """Read data from coordinator after adding to hass.""" + self._handle_coordinator_update() + await super().async_added_to_hass() + + @callback + def _handle_coordinator_update(self) -> None: """Update sensor with new departures times.""" - # Note: using Multi because there is a bug with the single stop impl - results = self._client.get_predictions_for_multi_stops( - [{"stop_tag": self.stop, "route_tag": self.route}], self.agency - ) + results = self.coordinator.get_prediction_data(self.stop, self.route) + self._attr_attribution = self.coordinator.get_attribution() self._log_debug("Predictions results: %s", results) - self._attr_attribution = results.get("copyright") - if "Error" in results: - self._log_debug("Could not get predictions: %s", results) - - if not results.get("predictions"): - self._log_debug("No predictions available") + if not results or "Error" in results: + self._log_err("Error getting predictions: %s", str(results)) self._attr_native_value = None - # Remove attributes that may now be outdated self._attr_extra_state_attributes.pop("upcoming", None) return - results = results["predictions"] - # Set detailed attributes self._attr_extra_state_attributes.update( { - "agency": results.get("agencyTitle"), - "route": results.get("routeTitle"), - "stop": results.get("stopTitle"), + "agency": str(results.get("agencyTitle")), + "route": str(results.get("routeTitle")), + "stop": str(results.get("stopTitle")), } ) @@ -171,14 +186,15 @@ class NextBusDepartureSensor(SensorEntity): self._log_debug("No upcoming predictions available") self._attr_native_value = None self._attr_extra_state_attributes["upcoming"] = "No upcoming predictions" - return + else: + # Generate list of upcoming times + self._attr_extra_state_attributes["upcoming"] = ", ".join( + sorted((p["minutes"] for p in predictions), key=int) + ) - # Generate list of upcoming times - self._attr_extra_state_attributes["upcoming"] = ", ".join( - sorted((p["minutes"] for p in predictions), key=int) - ) + latest_prediction = maybe_first(predictions) + self._attr_native_value = utc_from_timestamp( + int(latest_prediction["epochTime"]) / 1000 + ) - latest_prediction = maybe_first(predictions) - self._attr_native_value = utc_from_timestamp( - int(latest_prediction["epochTime"]) / 1000 - ) + self.async_write_ha_state() diff --git a/requirements_all.txt b/requirements_all.txt index c5b3316072b..f89f428beaa 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -1508,7 +1508,7 @@ py-dormakaba-dkey==1.0.5 py-melissa-climate==2.1.4 # homeassistant.components.nextbus -py-nextbusnext==0.1.5 +py-nextbusnext==1.0.0 # homeassistant.components.nightscout py-nightscout==1.2.2 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index 3338def76f7..03c4981188d 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1153,7 +1153,7 @@ py-dormakaba-dkey==1.0.5 py-melissa-climate==2.1.4 # homeassistant.components.nextbus -py-nextbusnext==0.1.5 +py-nextbusnext==1.0.0 # homeassistant.components.nightscout py-nightscout==1.2.2 diff --git a/tests/components/nextbus/test_sensor.py b/tests/components/nextbus/test_sensor.py index 071dd95fe7b..a4d04997e15 100644 --- a/tests/components/nextbus/test_sensor.py +++ b/tests/components/nextbus/test_sensor.py @@ -2,7 +2,9 @@ from collections.abc import Generator from copy import deepcopy from unittest.mock import MagicMock, patch +from urllib.error import HTTPError +from py_nextbus.client import NextBusFormatError, NextBusHTTPError, RouteStop import pytest from homeassistant.components import sensor @@ -12,10 +14,12 @@ from homeassistant.components.nextbus.const import ( CONF_STOP, DOMAIN, ) +from homeassistant.components.nextbus.coordinator import NextBusDataUpdateCoordinator from homeassistant.config_entries import ConfigEntryState from homeassistant.const import CONF_NAME from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant -from homeassistant.helpers import issue_registry as ir +from homeassistant.helpers import entity_registry as er, issue_registry as ir +from homeassistant.helpers.update_coordinator import UpdateFailed from homeassistant.setup import async_setup_component from tests.common import MockConfigEntry @@ -70,9 +74,7 @@ BASIC_RESULTS = { @pytest.fixture def mock_nextbus() -> Generator[MagicMock, None, None]: """Create a mock py_nextbus module.""" - with patch( - "homeassistant.components.nextbus.sensor.NextBusClient", - ) as client: + with patch("homeassistant.components.nextbus.coordinator.NextBusClient") as client: yield client @@ -89,7 +91,7 @@ def mock_nextbus_predictions( async def assert_setup_sensor( hass: HomeAssistant, - config: dict[str, str], + config: dict[str, dict[str, str]], expected_state=ConfigEntryState.LOADED, ) -> MockConfigEntry: """Set up the sensor and assert it's been created.""" @@ -144,9 +146,11 @@ async def test_verify_valid_state( ) -> None: """Verify all attributes are set from a valid response.""" await assert_setup_sensor(hass, CONFIG_BASIC) + entity = er.async_get(hass).async_get(SENSOR_ID) + assert entity mock_nextbus_predictions.assert_called_once_with( - [{"stop_tag": VALID_STOP, "route_tag": VALID_ROUTE}], VALID_AGENCY + {RouteStop(VALID_ROUTE, VALID_STOP)} ) state = hass.states.get(SENSOR_ID) @@ -272,6 +276,28 @@ async def test_direction_list( assert state.attributes["upcoming"] == "0, 1, 2, 3" +@pytest.mark.parametrize( + "client_exception", + ( + NextBusHTTPError("failed", HTTPError("url", 500, "error", MagicMock(), None)), + NextBusFormatError("failed"), + ), +) +async def test_prediction_exceptions( + hass: HomeAssistant, + mock_nextbus: MagicMock, + mock_nextbus_lists: MagicMock, + mock_nextbus_predictions: MagicMock, + client_exception: Exception, +) -> None: + """Test that some coodinator exceptions raise UpdateFailed exceptions.""" + await assert_setup_sensor(hass, CONFIG_BASIC) + coordinator: NextBusDataUpdateCoordinator = hass.data[DOMAIN][VALID_AGENCY] + mock_nextbus_predictions.side_effect = client_exception + with pytest.raises(UpdateFailed): + await coordinator._async_update_data() + + async def test_custom_name( hass: HomeAssistant, mock_nextbus: MagicMock,