Use setup method in coordinator for Trafikverket Train (#123138)

* Use setup method in coordinator for Trafikverket Train

* Overwrite types
This commit is contained in:
G Johansson 2024-08-11 14:15:20 +02:00 committed by GitHub
parent 9be8616cc0
commit e93d0dfdfc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 31 additions and 41 deletions

View File

@ -2,21 +2,11 @@
from __future__ import annotations from __future__ import annotations
from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import (
InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainStationFound,
)
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_FROM, CONF_TO, PLATFORMS from .const import PLATFORMS
from .coordinator import TVDataUpdateCoordinator from .coordinator import TVDataUpdateCoordinator
TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator] TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator]
@ -25,21 +15,7 @@ TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator]
async def async_setup_entry(hass: HomeAssistant, entry: TVTrainConfigEntry) -> bool: async def async_setup_entry(hass: HomeAssistant, entry: TVTrainConfigEntry) -> bool:
"""Set up Trafikverket Train from a config entry.""" """Set up Trafikverket Train from a config entry."""
http_session = async_get_clientsession(hass) coordinator = TVDataUpdateCoordinator(hass)
train_api = TrafikverketTrain(http_session, entry.data[CONF_API_KEY])
try:
to_station = await train_api.async_get_train_station(entry.data[CONF_TO])
from_station = await train_api.async_get_train_station(entry.data[CONF_FROM])
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
except (NoTrainStationFound, MultipleTrainStationsFound) as error:
raise ConfigEntryNotReady(
f"Problem when trying station {entry.data[CONF_FROM]} to"
f" {entry.data[CONF_TO]}. Error: {error} "
) from error
coordinator = TVDataUpdateCoordinator(hass, to_station, from_station)
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()
entry.runtime_data = coordinator entry.runtime_data = coordinator

View File

@ -10,7 +10,9 @@ from typing import TYPE_CHECKING
from pytrafikverket import TrafikverketTrain from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import ( from pytrafikverket.exceptions import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound, NoTrainAnnouncementFound,
NoTrainStationFound,
UnknownError, UnknownError,
) )
from pytrafikverket.models import StationInfoModel, TrainStopModel from pytrafikverket.models import StationInfoModel, TrainStopModel
@ -22,7 +24,7 @@ from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import CONF_FILTER_PRODUCT, CONF_TIME, DOMAIN from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import next_departuredate from .util import next_departuredate
if TYPE_CHECKING: if TYPE_CHECKING:
@ -69,13 +71,10 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
"""A Trafikverket Data Update Coordinator.""" """A Trafikverket Data Update Coordinator."""
config_entry: TVTrainConfigEntry config_entry: TVTrainConfigEntry
from_station: StationInfoModel
to_station: StationInfoModel
def __init__( def __init__(self, hass: HomeAssistant) -> None:
self,
hass: HomeAssistant,
to_station: StationInfoModel,
from_station: StationInfoModel,
) -> None:
"""Initialize the Trafikverket coordinator.""" """Initialize the Trafikverket coordinator."""
super().__init__( super().__init__(
hass, hass,
@ -86,14 +85,29 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
self._train_api = TrafikverketTrain( self._train_api = TrafikverketTrain(
async_get_clientsession(hass), self.config_entry.data[CONF_API_KEY] async_get_clientsession(hass), self.config_entry.data[CONF_API_KEY]
) )
self.from_station: StationInfoModel = from_station
self.to_station: StationInfoModel = to_station
self._time: time | None = dt_util.parse_time(self.config_entry.data[CONF_TIME]) self._time: time | None = dt_util.parse_time(self.config_entry.data[CONF_TIME])
self._weekdays: list[str] = self.config_entry.data[CONF_WEEKDAY] self._weekdays: list[str] = self.config_entry.data[CONF_WEEKDAY]
self._filter_product: str | None = self.config_entry.options.get( self._filter_product: str | None = self.config_entry.options.get(
CONF_FILTER_PRODUCT CONF_FILTER_PRODUCT
) )
async def _async_setup(self) -> None:
"""Initiate stations."""
try:
self.to_station = await self._train_api.async_get_train_station(
self.config_entry.data[CONF_TO]
)
self.from_station = await self._train_api.async_get_train_station(
self.config_entry.data[CONF_FROM]
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
except (NoTrainStationFound, MultipleTrainStationsFound) as error:
raise UpdateFailed(
f"Problem when trying station {self.config_entry.data[CONF_FROM]} to"
f" {self.config_entry.data[CONF_TO]}. Error: {error} "
) from error
async def _async_update_data(self) -> TrainData: async def _async_update_data(self) -> TrainData:
"""Fetch data from Trafikverket.""" """Fetch data from Trafikverket."""

View File

@ -38,7 +38,7 @@ async def load_integration_from_entry(
return_value=get_train_stop, return_value=get_train_stop,
), ),
patch( patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
), ),
): ):
await hass.config_entries.async_setup(config_entry_id) await hass.config_entries.async_setup(config_entry_id)

View File

@ -499,7 +499,7 @@ async def test_options_flow(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",

View File

@ -34,7 +34,7 @@ async def test_unload_entry(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -69,7 +69,7 @@ async def test_auth_failed(
entry.add_to_hass(hass) entry.add_to_hass(hass)
with patch( with patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
side_effect=InvalidAuthentication, side_effect=InvalidAuthentication,
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -99,7 +99,7 @@ async def test_no_stations(
entry.add_to_hass(hass) entry.add_to_hass(hass)
with patch( with patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
side_effect=NoTrainStationFound, side_effect=NoTrainStationFound,
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -135,7 +135,7 @@ async def test_migrate_entity_unique_id(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",