Use runtime_data for Swiss Public Transport (#128369)

* use runtime_data instead of hass.data[<key>]

* fix service response export type

* reduce runtime_data to be just the coordinator

* fix rebase

* fix ruff

* address reviews

* address reviews

* no general core import

* no general config_entries import

* fix also for services

* remove untyped config entry

* remove unneeded cast
This commit is contained in:
Cyrill Raccaud 2024-10-21 11:50:22 +02:00 committed by GitHub
parent 0d447c9d50
commit 110751e992
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 31 additions and 23 deletions

View File

@ -8,8 +8,8 @@ from opendata_transport.exceptions import (
OpendataTransportError,
)
from homeassistant import config_entries, core
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers import (
config_validation as cv,
@ -20,7 +20,10 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .coordinator import (
SwissPublicTransportConfigEntry,
SwissPublicTransportDataUpdateCoordinator,
)
from .helper import unique_id_from_config
from .services import setup_services
@ -32,14 +35,14 @@ PLATFORMS: list[Platform] = [Platform.SENSOR]
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: core.HomeAssistant, config: ConfigType) -> bool:
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the Swiss public transport component."""
setup_services(hass)
return True
async def async_setup_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry
hass: HomeAssistant, entry: SwissPublicTransportConfigEntry
) -> bool:
"""Set up Swiss public transport from a config entry."""
config = entry.data
@ -74,24 +77,21 @@ async def async_setup_entry(
coordinator = SwissPublicTransportDataUpdateCoordinator(hass, opendata)
await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
entry.runtime_data = coordinator
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
return True
async def async_unload_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry
hass: HomeAssistant, entry: SwissPublicTransportConfigEntry
) -> bool:
"""Unload a config entry."""
if unload_ok := await hass.config_entries.async_unload_platforms(entry, PLATFORMS):
hass.data[DOMAIN].pop(entry.entry_id)
return unload_ok
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
async def async_migrate_entry(
hass: core.HomeAssistant, config_entry: config_entries.ConfigEntry
hass: HomeAssistant, config_entry: SwissPublicTransportConfigEntry
) -> bool:
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", config_entry.version)

View File

@ -22,6 +22,10 @@ from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN
_LOGGER = logging.getLogger(__name__)
type SwissPublicTransportConfigEntry = ConfigEntry[
SwissPublicTransportDataUpdateCoordinator
]
class DataConnection(TypedDict):
"""A connection data class."""
@ -51,7 +55,7 @@ class SwissPublicTransportDataUpdateCoordinator(
):
"""A SwissPublicTransport Data Update Coordinator."""
config_entry: ConfigEntry
config_entry: SwissPublicTransportConfigEntry
def __init__(self, hass: HomeAssistant, opendata: OpendataTransport) -> None:
"""Initialize the SwissPublicTransport data coordinator."""

View File

@ -8,20 +8,24 @@ from datetime import datetime, timedelta
import logging
from typing import TYPE_CHECKING
from homeassistant import config_entries, core
from homeassistant.components.sensor import (
SensorDeviceClass,
SensorEntity,
SensorEntityDescription,
)
from homeassistant.const import UnitOfTime
from homeassistant.core import HomeAssistant
from homeassistant.helpers.device_registry import DeviceEntryType, DeviceInfo
from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import CONNECTIONS_COUNT, DOMAIN
from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator
from .coordinator import (
DataConnection,
SwissPublicTransportConfigEntry,
SwissPublicTransportDataUpdateCoordinator,
)
_LOGGER = logging.getLogger(__name__)
@ -80,20 +84,18 @@ SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = (
async def async_setup_entry(
hass: core.HomeAssistant,
config_entry: config_entries.ConfigEntry,
hass: HomeAssistant,
config_entry: SwissPublicTransportConfigEntry,
async_add_entities: AddEntitiesCallback,
) -> None:
"""Set up the sensor from a config entry created in the integrations UI."""
coordinator = hass.data[DOMAIN][config_entry.entry_id]
unique_id = config_entry.unique_id
if TYPE_CHECKING:
assert unique_id
async_add_entities(
SwissPublicTransportSensor(coordinator, description, unique_id)
SwissPublicTransportSensor(config_entry.runtime_data, description, unique_id)
for description in SENSORS
)

View File

@ -2,7 +2,6 @@
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import (
HomeAssistant,
@ -26,6 +25,7 @@ from .const import (
DOMAIN,
SERVICE_FETCH_CONNECTIONS,
)
from .coordinator import SwissPublicTransportConfigEntry
SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
{
@ -41,7 +41,7 @@ SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
def async_get_entry(
hass: HomeAssistant, config_entry_id: str
) -> config_entries.ConfigEntry:
) -> SwissPublicTransportConfigEntry:
"""Get the Swiss public transport config entry."""
if not (entry := hass.config_entries.async_get_entry(config_entry_id)):
raise ServiceValidationError(
@ -66,10 +66,12 @@ def setup_services(hass: HomeAssistant) -> None:
) -> ServiceResponse:
"""Fetch a set of connections."""
config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID])
limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT
coordinator = hass.data[DOMAIN][config_entry.entry_id]
try:
connections = await coordinator.fetch_connections_as_json(limit=int(limit))
connections = await config_entry.runtime_data.fetch_connections_as_json(
limit=int(limit)
)
except UpdateFailed as e:
raise HomeAssistantError(
translation_domain=DOMAIN,