Explicitly pass in the config_entry in israel_rail coordinator (#138132)

explicitly pass in the config_entry in coordinator
This commit is contained in:
Michael 2025-02-09 20:58:09 +01:00 committed by GitHub
parent 52363d5369
commit 733d9de042
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 13 additions and 9 deletions

View File

@ -4,13 +4,12 @@ import logging
from israelrailapi import TrainSchedule from israelrailapi import TrainSchedule
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryNotReady from homeassistant.exceptions import ConfigEntryNotReady
from .const import CONF_DESTINATION, CONF_START, DOMAIN from .const import CONF_DESTINATION, CONF_START, DOMAIN
from .coordinator import IsraelRailDataUpdateCoordinator from .coordinator import IsraelRailConfigEntry, IsraelRailDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -18,9 +17,6 @@ _LOGGER = logging.getLogger(__name__)
PLATFORMS: list[Platform] = [Platform.SENSOR] PLATFORMS: list[Platform] = [Platform.SENSOR]
type IsraelRailConfigEntry = ConfigEntry[IsraelRailDataUpdateCoordinator]
async def async_setup_entry(hass: HomeAssistant, entry: IsraelRailConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: IsraelRailConfigEntry) -> bool:
"""Set up Israel rail from a config entry.""" """Set up Israel rail from a config entry."""
config = entry.data config = entry.data
@ -43,7 +39,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: IsraelRailConfigEntry) -
) from e ) from e
israel_rail_coordinator = IsraelRailDataUpdateCoordinator( israel_rail_coordinator = IsraelRailDataUpdateCoordinator(
hass, train_schedule, start, destination hass, entry, train_schedule, start, destination
) )
await israel_rail_coordinator.async_config_entry_first_refresh() await israel_rail_coordinator.async_config_entry_first_refresh()
entry.runtime_data = israel_rail_coordinator entry.runtime_data = israel_rail_coordinator

View File

@ -38,14 +38,18 @@ def departure_time(train_route: TrainRoute) -> datetime | None:
return start_datetime.astimezone() if start_datetime else None return start_datetime.astimezone() if start_datetime else None
type IsraelRailConfigEntry = ConfigEntry[IsraelRailDataUpdateCoordinator]
class IsraelRailDataUpdateCoordinator(DataUpdateCoordinator[list[DataConnection]]): class IsraelRailDataUpdateCoordinator(DataUpdateCoordinator[list[DataConnection]]):
"""A IsraelRail Data Update Coordinator.""" """A IsraelRail Data Update Coordinator."""
config_entry: ConfigEntry config_entry: IsraelRailConfigEntry
def __init__( def __init__(
self, self,
hass: HomeAssistant, hass: HomeAssistant,
config_entry: IsraelRailConfigEntry,
train_schedule: TrainSchedule, train_schedule: TrainSchedule,
start: str, start: str,
destination: str, destination: str,
@ -54,6 +58,7 @@ class IsraelRailDataUpdateCoordinator(DataUpdateCoordinator[list[DataConnection]
super().__init__( super().__init__(
hass, hass,
_LOGGER, _LOGGER,
config_entry=config_entry,
name=DOMAIN, name=DOMAIN,
update_interval=DEFAULT_SCAN_INTERVAL, update_interval=DEFAULT_SCAN_INTERVAL,
) )

View File

@ -19,9 +19,12 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity from homeassistant.helpers.update_coordinator import CoordinatorEntity
from . import IsraelRailConfigEntry
from .const import ATTRIBUTION, DEPARTURES_COUNT, DOMAIN from .const import ATTRIBUTION, DEPARTURES_COUNT, DOMAIN
from .coordinator import DataConnection, IsraelRailDataUpdateCoordinator from .coordinator import (
DataConnection,
IsraelRailConfigEntry,
IsraelRailDataUpdateCoordinator,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)