From 110751e9923571835bedb68e90fb30418d44e6d8 Mon Sep 17 00:00:00 2001 From: Cyrill Raccaud Date: Mon, 21 Oct 2024 11:50:22 +0200 Subject: [PATCH] Use runtime_data for Swiss Public Transport (#128369) * use runtime_data instead of hass.data[] * fix service response export type * reduce runtime_data to be just the coordinator * fix rebase * fix ruff * address reviews * address reviews * no general core import * no general config_entries import * fix also for services * remove untyped config entry * remove unneeded cast --- .../swiss_public_transport/__init__.py | 22 +++++++++---------- .../swiss_public_transport/coordinator.py | 6 ++++- .../swiss_public_transport/sensor.py | 16 ++++++++------ .../swiss_public_transport/services.py | 10 +++++---- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index dc1d0eb236c..bceac6007a2 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -8,8 +8,8 @@ from opendata_transport.exceptions import ( OpendataTransportError, ) -from homeassistant import config_entries, core from homeassistant.const import Platform +from homeassistant.core import HomeAssistant from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady from homeassistant.helpers import ( config_validation as cv, @@ -20,7 +20,10 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.typing import ConfigType from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS -from .coordinator import SwissPublicTransportDataUpdateCoordinator +from .coordinator import ( + SwissPublicTransportConfigEntry, + SwissPublicTransportDataUpdateCoordinator, +) from .helper import unique_id_from_config from .services import setup_services @@ -32,14 +35,14 @@ PLATFORMS: list[Platform] = [Platform.SENSOR] CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) -async def async_setup(hass: core.HomeAssistant, config: ConfigType) -> bool: +async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the Swiss public transport component.""" setup_services(hass) return True async def async_setup_entry( - hass: core.HomeAssistant, entry: config_entries.ConfigEntry + hass: HomeAssistant, entry: SwissPublicTransportConfigEntry ) -> bool: """Set up Swiss public transport from a config entry.""" config = entry.data @@ -74,24 +77,21 @@ async def async_setup_entry( coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata) await coordinator.async_config_entry_first_refresh() - hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator + entry.runtime_data = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) return True async def async_unload_entry( - hass: core.HomeAssistant, entry: config_entries.ConfigEntry + hass: HomeAssistant, entry: SwissPublicTransportConfigEntry ) -> bool: """Unload a config entry.""" - if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): - hass.data[DOMAIN].pop(entry.entry_id) - - return unload_ok + return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) async def async_migrate_entry( - hass: core.HomeAssistant, config_entry: config_entries.ConfigEntry + hass: HomeAssistant, config_entry: SwissPublicTransportConfigEntry ) -> bool: """Migrate config entry.""" _LOGGER.debug("Migrating from version %s", config_entry.version) diff --git a/homeassistant/components/swiss_public_transport/coordinator.py b/homeassistant/components/swiss_public_transport/coordinator.py index 5d51175fb26..ff14e81a44e 100644 --- a/homeassistant/components/swiss_public_transport/coordinator.py +++ b/homeassistant/components/swiss_public_transport/coordinator.py @@ -22,6 +22,10 @@ from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN _LOGGER = logging.getLogger(__name__) +type SwissPublicTransportConfigEntry = ConfigEntry[ + SwissPublicTransportDataUpdateCoordinator +] + class DataConnection(TypedDict): """A connection data class.""" @@ -51,7 +55,7 @@ class SwissPublicTransportDataUpdateCoordinator( ): """A SwissPublicTransport Data Update Coordinator.""" - config_entry: ConfigEntry + config_entry: SwissPublicTransportConfigEntry def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None: """Initialize the SwissPublicTransport data coordinator.""" diff --git a/homeassistant/components/swiss_public_transport/sensor.py b/homeassistant/components/swiss_public_transport/sensor.py index eb73ce03062..452ec31972f 100644 --- a/homeassistant/components/swiss_public_transport/sensor.py +++ b/homeassistant/components/swiss_public_transport/sensor.py @@ -8,20 +8,24 @@ from datetime import datetime, timedelta import logging from typing import TYPE_CHECKING -from homeassistant import config_entries, core from homeassistant.components.sensor import ( SensorDeviceClass, SensorEntity, SensorEntityDescription, ) from homeassistant.const import UnitOfTime +from homeassistant.core import HomeAssistant from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity from .const import CONNECTIONS_COUNT, DOMAIN -from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator +from .coordinator import ( + DataConnection, + SwissPublicTransportConfigEntry, + SwissPublicTransportDataUpdateCoordinator, +) _LOGGER = logging.getLogger(__name__) @@ -80,20 +84,18 @@ SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = ( async def async_setup_entry( - hass: core.HomeAssistant, - config_entry: config_entries.ConfigEntry, + hass: HomeAssistant, + config_entry: SwissPublicTransportConfigEntry, async_add_entities: AddEntitiesCallback, ) -> None: """Set up the sensor from a config entry created in the integrations UI.""" - coordinator = hass.data[DOMAIN][config_entry.entry_id] - unique_id = config_entry.unique_id if TYPE_CHECKING: assert unique_id async_add_entities( - SwissPublicTransportSensor(coordinator, description, unique_id) + SwissPublicTransportSensor(config_entry.runtime_data, description, unique_id) for description in SENSORS ) diff --git a/homeassistant/components/swiss_public_transport/services.py b/homeassistant/components/swiss_public_transport/services.py index 4ede91e6c42..3abf1a14b9f 100644 --- a/homeassistant/components/swiss_public_transport/services.py +++ b/homeassistant/components/swiss_public_transport/services.py @@ -2,7 +2,6 @@ import voluptuous as vol -from homeassistant import config_entries from homeassistant.config_entries import ConfigEntryState from homeassistant.core import ( HomeAssistant, @@ -26,6 +25,7 @@ from .const import ( DOMAIN, SERVICE_FETCH_CONNECTIONS, ) +from .coordinator import SwissPublicTransportConfigEntry SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema( { @@ -41,7 +41,7 @@ SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema( def async_get_entry( hass: HomeAssistant, config_entry_id: str -) -> config_entries.ConfigEntry: +) -> SwissPublicTransportConfigEntry: """Get the Swiss public transport config entry.""" if not (entry := hass.config_entries.async_get_entry(config_entry_id)): raise ServiceValidationError( @@ -66,10 +66,12 @@ def setup_services(hass: HomeAssistant) -> None: ) -> ServiceResponse: """Fetch a set of connections.""" config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID]) + limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT - coordinator = hass.data[DOMAIN][config_entry.entry_id] try: - connections = await coordinator.fetch_connections_as_json(limit=int(limit)) + connections = await config_entry.runtime_data.fetch_connections_as_json( + limit=int(limit) + ) except UpdateFailed as e: raise HomeAssistantError( translation_domain=DOMAIN,