Use runtime_data for Swiss Public Transport (#128369)

* use runtime_data instead of hass.data[<key>]

* fix service response export type

* reduce runtime_data to be just the coordinator

* fix rebase

* fix ruff

* address reviews

* address reviews

* no general core import

* no general config_entries import

* fix also for services

* remove untyped config entry

* remove unneeded cast
This commit is contained in:
Cyrill Raccaud 2024-10-21 11:50:22 +02:00 committed by GitHub
parent 0d447c9d50
commit 110751e992
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 23 deletions

View File

@ -8,8 +8,8 @@ from opendata_transport.exceptions import (
OpendataTransportError, OpendataTransportError,
) )
from homeassistant import config_entries, core
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers import ( from homeassistant.helpers import (
config_validation as cv, config_validation as cv,
@ -20,7 +20,10 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator from .coordinator import (
SwissPublicTransportConfigEntry,
SwissPublicTransportDataUpdateCoordinator,
)
from .helper import unique_id_from_config from .helper import unique_id_from_config
from .services import setup_services from .services import setup_services
@ -32,14 +35,14 @@ PLATFORMS: list[Platform] = [Platform.SENSOR]
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: core.HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Swiss public transport component.""" """Set up the Swiss public transport component."""
setup_services(hass) setup_services(hass)
return True return True
async def async_setup_entry( async def async_setup_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry hass: HomeAssistant, entry: SwissPublicTransportConfigEntry
) -> bool: ) -> bool:
"""Set up Swiss public transport from a config entry.""" """Set up Swiss public transport from a config entry."""
config = entry.data config = entry.data
@ -74,24 +77,21 @@ async def async_setup_entry(
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata) coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True return True
async def async_unload_entry( async def async_unload_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry hass: HomeAssistant, entry: SwissPublicTransportConfigEntry
) -> bool: ) -> bool:
"""Unload a config entry.""" """Unload a config entry."""
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS): return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
async def async_migrate_entry( async def async_migrate_entry(
hass: core.HomeAssistant, config_entry: config_entries.ConfigEntry hass: HomeAssistant, config_entry: SwissPublicTransportConfigEntry
) -> bool: ) -> bool:
"""Migrate config entry.""" """Migrate config entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version) _LOGGER.debug("Migrating from version %s", config_entry.version)

View File

@ -22,6 +22,10 @@ from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
type SwissPublicTransportConfigEntry = ConfigEntry[
SwissPublicTransportDataUpdateCoordinator
]
class DataConnection(TypedDict): class DataConnection(TypedDict):
"""A connection data class.""" """A connection data class."""
@ -51,7 +55,7 @@ class SwissPublicTransportDataUpdateCoordinator(
): ):
"""A SwissPublicTransport Data Update Coordinator.""" """A SwissPublicTransport Data Update Coordinator."""
config_entry: ConfigEntry config_entry: SwissPublicTransportConfigEntry
def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None: def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
"""Initialize the SwissPublicTransport data coordinator.""" """Initialize the SwissPublicTransport data coordinator."""

View File

@ -8,20 +8,24 @@ from datetime import datetime, timedelta
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from homeassistant import config_entries, core
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
SensorDeviceClass, SensorDeviceClass,
SensorEntity, SensorEntity,
SensorEntityDescription, SensorEntityDescription,
) )
from homeassistant.const import UnitOfTime from homeassistant.const import UnitOfTime
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import CONNECTIONS_COUNT, DOMAIN from .const import CONNECTIONS_COUNT, DOMAIN
from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator from .coordinator import (
DataConnection,
SwissPublicTransportConfigEntry,
SwissPublicTransportDataUpdateCoordinator,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -80,20 +84,18 @@ SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = (
async def async_setup_entry( async def async_setup_entry(
hass: core.HomeAssistant, hass: HomeAssistant,
config_entry: config_entries.ConfigEntry, config_entry: SwissPublicTransportConfigEntry,
async_add_entities: AddEntitiesCallback, async_add_entities: AddEntitiesCallback,
) -> None: ) -> None:
"""Set up the sensor from a config entry created in the integrations UI.""" """Set up the sensor from a config entry created in the integrations UI."""
coordinator = hass.data[DOMAIN][config_entry.entry_id]
unique_id = config_entry.unique_id unique_id = config_entry.unique_id
if TYPE_CHECKING: if TYPE_CHECKING:
assert unique_id assert unique_id
async_add_entities( async_add_entities(
SwissPublicTransportSensor(coordinator, description, unique_id) SwissPublicTransportSensor(config_entry.runtime_data, description, unique_id)
for description in SENSORS for description in SENSORS
) )

View File

@ -2,7 +2,6 @@
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntryState from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import ( from homeassistant.core import (
HomeAssistant, HomeAssistant,
@ -26,6 +25,7 @@ from .const import (
DOMAIN, DOMAIN,
SERVICE_FETCH_CONNECTIONS, SERVICE_FETCH_CONNECTIONS,
) )
from .coordinator import SwissPublicTransportConfigEntry
SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema( SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
{ {
@ -41,7 +41,7 @@ SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
def async_get_entry( def async_get_entry(
hass: HomeAssistant, config_entry_id: str hass: HomeAssistant, config_entry_id: str
) -> config_entries.ConfigEntry: ) -> SwissPublicTransportConfigEntry:
"""Get the Swiss public transport config entry.""" """Get the Swiss public transport config entry."""
if not (entry := hass.config_entries.async_get_entry(config_entry_id)): if not (entry := hass.config_entries.async_get_entry(config_entry_id)):
raise ServiceValidationError( raise ServiceValidationError(
@ -66,10 +66,12 @@ def setup_services(hass: HomeAssistant) -> None:
) -> ServiceResponse: ) -> ServiceResponse:
"""Fetch a set of connections.""" """Fetch a set of connections."""
config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID]) config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID])
limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT
coordinator = hass.data[DOMAIN][config_entry.entry_id]
try: try:
connections = await coordinator.fetch_connections_as_json(limit=int(limit)) connections = await config_entry.runtime_data.fetch_connections_as_json(
limit=int(limit)
)
except UpdateFailed as e: except UpdateFailed as e:
raise HomeAssistantError( raise HomeAssistantError(
translation_domain=DOMAIN, translation_domain=DOMAIN,