Add typing/async to NMBS (#139002)

* Add typing/async to NMBS

* Fix tests

* Boolean fields

* Update homeassistant/components/nmbs/sensor.py

Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>

---------

Co-authored-by: Shay Levy <levyshay1@gmail.com>
Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>
This commit is contained in:
Simon Lamon 2025-03-02 17:36:37 +01:00 committed by GitHub
parent de4540c68e
commit 40099547ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 101 additions and 126 deletions

View File

@ -8,6 +8,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN from .const import DOMAIN
@ -22,13 +23,13 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the NMBS component.""" """Set up the NMBS component."""
api_client = iRail() api_client = iRail(session=async_get_clientsession(hass))
hass.data.setdefault(DOMAIN, {}) hass.data.setdefault(DOMAIN, {})
station_response = await hass.async_add_executor_job(api_client.get_stations) station_response = await api_client.get_stations()
if station_response == -1: if station_response is None:
return False return False
hass.data[DOMAIN] = station_response["station"] hass.data[DOMAIN] = station_response.stations
return True return True

View File

@ -3,11 +3,13 @@
from typing import Any from typing import Any
from pyrail import iRail from pyrail import iRail
from pyrail.models import StationDetails
import voluptuous as vol import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import Platform from homeassistant.const import Platform
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
BooleanSelector, BooleanSelector,
SelectOptionDict, SelectOptionDict,
@ -31,17 +33,15 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self) -> None: def __init__(self) -> None:
"""Initialize.""" """Initialize."""
self.api_client = iRail() self.stations: list[StationDetails] = []
self.stations: list[dict[str, Any]] = []
async def _fetch_stations(self) -> list[dict[str, Any]]: async def _fetch_stations(self) -> list[StationDetails]:
"""Fetch the stations.""" """Fetch the stations."""
stations_response = await self.hass.async_add_executor_job( api_client = iRail(session=async_get_clientsession(self.hass))
self.api_client.get_stations stations_response = await api_client.get_stations()
) if stations_response is None:
if stations_response == -1:
raise CannotConnect("The API is currently unavailable.") raise CannotConnect("The API is currently unavailable.")
return stations_response["station"] return stations_response.stations
async def _fetch_stations_choices(self) -> list[SelectOptionDict]: async def _fetch_stations_choices(self) -> list[SelectOptionDict]:
"""Fetch the stations options.""" """Fetch the stations options."""
@ -50,7 +50,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
self.stations = await self._fetch_stations() self.stations = await self._fetch_stations()
return [ return [
SelectOptionDict(value=station["id"], label=station["standardname"]) SelectOptionDict(value=station.id, label=station.standard_name)
for station in self.stations for station in self.stations
] ]
@ -72,12 +72,12 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
[station_from] = [ [station_from] = [
station station
for station in self.stations for station in self.stations
if station["id"] == user_input[CONF_STATION_FROM] if station.id == user_input[CONF_STATION_FROM]
] ]
[station_to] = [ [station_to] = [
station station
for station in self.stations for station in self.stations
if station["id"] == user_input[CONF_STATION_TO] if station.id == user_input[CONF_STATION_TO]
] ]
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else "" vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else ""
await self.async_set_unique_id( await self.async_set_unique_id(
@ -85,7 +85,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
) )
self._abort_if_unique_id_configured() self._abort_if_unique_id_configured()
config_entry_name = f"Train from {station_from['standardname']} to {station_to['standardname']}" config_entry_name = f"Train from {station_from.standard_name} to {station_to.standard_name}"
return self.async_create_entry( return self.async_create_entry(
title=config_entry_name, title=config_entry_name,
data=user_input, data=user_input,
@ -127,18 +127,18 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
station_live = None station_live = None
for station in self.stations: for station in self.stations:
if user_input[CONF_STATION_FROM] in ( if user_input[CONF_STATION_FROM] in (
station["standardname"], station.standard_name,
station["name"], station.name,
): ):
station_from = station station_from = station
if user_input[CONF_STATION_TO] in ( if user_input[CONF_STATION_TO] in (
station["standardname"], station.standard_name,
station["name"], station.name,
): ):
station_to = station station_to = station
if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in ( if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in (
station["standardname"], station.standard_name,
station["name"], station.name,
): ):
station_live = station station_live = station
@ -148,29 +148,29 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="same_station") return self.async_abort(reason="same_station")
# config flow uses id and not the standard name # config flow uses id and not the standard name
user_input[CONF_STATION_FROM] = station_from["id"] user_input[CONF_STATION_FROM] = station_from.id
user_input[CONF_STATION_TO] = station_to["id"] user_input[CONF_STATION_TO] = station_to.id
if station_live: if station_live:
user_input[CONF_STATION_LIVE] = station_live["id"] user_input[CONF_STATION_LIVE] = station_live.id
entity_registry = er.async_get(self.hass) entity_registry = er.async_get(self.hass)
prefix = "live" prefix = "live"
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else "" vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else ""
if entity_id := entity_registry.async_get_entity_id( if entity_id := entity_registry.async_get_entity_id(
Platform.SENSOR, Platform.SENSOR,
DOMAIN, DOMAIN,
f"{prefix}_{station_live['standardname']}_{station_from['standardname']}_{station_to['standardname']}", f"{prefix}_{station_live.standard_name}_{station_from.standard_name}_{station_to.standard_name}",
): ):
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}" new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
entity_registry.async_update_entity( entity_registry.async_update_entity(
entity_id, new_unique_id=new_unique_id entity_id, new_unique_id=new_unique_id
) )
if entity_id := entity_registry.async_get_entity_id( if entity_id := entity_registry.async_get_entity_id(
Platform.SENSOR, Platform.SENSOR,
DOMAIN, DOMAIN,
f"{prefix}_{station_live['name']}_{station_from['name']}_{station_to['name']}", f"{prefix}_{station_live.name}_{station_from.name}_{station_to.name}",
): ):
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}" new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
entity_registry.async_update_entity( entity_registry.async_update_entity(
entity_id, new_unique_id=new_unique_id entity_id, new_unique_id=new_unique_id
) )

View File

@ -19,11 +19,7 @@ CONF_SHOW_ON_MAP = "show_on_map"
def find_station_by_name(hass: HomeAssistant, station_name: str): def find_station_by_name(hass: HomeAssistant, station_name: str):
"""Find given station_name in the station list.""" """Find given station_name in the station list."""
return next( return next(
( (s for s in hass.data[DOMAIN] if station_name in (s.standard_name, s.name)),
s
for s in hass.data[DOMAIN]
if station_name in (s["standardname"], s["name"])
),
None, None,
) )
@ -31,6 +27,6 @@ def find_station_by_name(hass: HomeAssistant, station_name: str):
def find_station(hass: HomeAssistant, station_name: str): def find_station(hass: HomeAssistant, station_name: str):
"""Find given station_id in the station list.""" """Find given station_id in the station list."""
return next( return next(
(s for s in hass.data[DOMAIN] if station_name in s["id"]), (s for s in hass.data[DOMAIN] if station_name in s.id),
None, None,
) )

View File

@ -7,5 +7,5 @@
"iot_class": "cloud_polling", "iot_class": "cloud_polling",
"loggers": ["pyrail"], "loggers": ["pyrail"],
"quality_scale": "legacy", "quality_scale": "legacy",
"requirements": ["pyrail==0.0.3"] "requirements": ["pyrail==0.4.1"]
} }

View File

@ -2,10 +2,12 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
import logging import logging
from typing import Any from typing import Any
from pyrail import iRail from pyrail import iRail
from pyrail.models import ConnectionDetails, LiveboardDeparture, StationDetails
import voluptuous as vol import voluptuous as vol
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
@ -23,6 +25,7 @@ from homeassistant.const import (
) )
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.helpers import config_validation as cv from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity_platform import ( from homeassistant.helpers.entity_platform import (
AddConfigEntryEntitiesCallback, AddConfigEntryEntitiesCallback,
AddEntitiesCallback, AddEntitiesCallback,
@ -44,8 +47,6 @@ from .const import ( # noqa: F401
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
API_FAILURE = -1
DEFAULT_NAME = "NMBS" DEFAULT_NAME = "NMBS"
DEFAULT_ICON = "mdi:train" DEFAULT_ICON = "mdi:train"
@ -63,12 +64,12 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend(
) )
def get_time_until(departure_time=None): def get_time_until(departure_time: datetime | None = None):
"""Calculate the time between now and a train's departure time.""" """Calculate the time between now and a train's departure time."""
if departure_time is None: if departure_time is None:
return 0 return 0
delta = dt_util.utc_from_timestamp(int(departure_time)) - dt_util.now() delta = dt_util.as_utc(departure_time) - dt_util.utcnow()
return round(delta.total_seconds() / 60) return round(delta.total_seconds() / 60)
@ -77,11 +78,9 @@ def get_delay_in_minutes(delay=0):
return round(int(delay) / 60) return round(int(delay) / 60)
def get_ride_duration(departure_time, arrival_time, delay=0): def get_ride_duration(departure_time: datetime, arrival_time: datetime, delay=0):
"""Calculate the total travel time in minutes.""" """Calculate the total travel time in minutes."""
duration = dt_util.utc_from_timestamp( duration = arrival_time - departure_time
int(arrival_time)
) - dt_util.utc_from_timestamp(int(departure_time))
duration_time = int(round(duration.total_seconds() / 60)) duration_time = int(round(duration.total_seconds() / 60))
return duration_time + get_delay_in_minutes(delay) return duration_time + get_delay_in_minutes(delay)
@ -157,7 +156,7 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback, async_add_entities: AddConfigEntryEntitiesCallback,
) -> None: ) -> None:
"""Set up NMBS sensor entities based on a config entry.""" """Set up NMBS sensor entities based on a config entry."""
api_client = iRail() api_client = iRail(session=async_get_clientsession(hass))
name = config_entry.data.get(CONF_NAME, None) name = config_entry.data.get(CONF_NAME, None)
show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False) show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False)
@ -189,9 +188,9 @@ class NMBSLiveBoard(SensorEntity):
def __init__( def __init__(
self, self,
api_client: iRail, api_client: iRail,
live_station: dict[str, Any], live_station: StationDetails,
station_from: dict[str, Any], station_from: StationDetails,
station_to: dict[str, Any], station_to: StationDetails,
excl_vias: bool, excl_vias: bool,
) -> None: ) -> None:
"""Initialize the sensor for getting liveboard data.""" """Initialize the sensor for getting liveboard data."""
@ -201,7 +200,8 @@ class NMBSLiveBoard(SensorEntity):
self._station_to = station_to self._station_to = station_to
self._excl_vias = excl_vias self._excl_vias = excl_vias
self._attrs: dict[str, Any] | None = {} self._attrs: LiveboardDeparture | None = None
self._state: str | None = None self._state: str | None = None
self.entity_registry_enabled_default = False self.entity_registry_enabled_default = False
@ -209,22 +209,20 @@ class NMBSLiveBoard(SensorEntity):
@property @property
def name(self) -> str: def name(self) -> str:
"""Return the sensor default name.""" """Return the sensor default name."""
return f"Trains in {self._station['standardname']}" return f"Trains in {self._station.standard_name}"
@property @property
def unique_id(self) -> str: def unique_id(self) -> str:
"""Return the unique ID.""" """Return the unique ID."""
unique_id = ( unique_id = f"{self._station.id}_{self._station_from.id}_{self._station_to.id}"
f"{self._station['id']}_{self._station_from['id']}_{self._station_to['id']}"
)
vias = "_excl_vias" if self._excl_vias else "" vias = "_excl_vias" if self._excl_vias else ""
return f"nmbs_live_{unique_id}{vias}" return f"nmbs_live_{unique_id}{vias}"
@property @property
def icon(self) -> str: def icon(self) -> str:
"""Return the default icon or an alert icon if delays.""" """Return the default icon or an alert icon if delays."""
if self._attrs and int(self._attrs["delay"]) > 0: if self._attrs and int(self._attrs.delay) > 0:
return DEFAULT_ICON_ALERT return DEFAULT_ICON_ALERT
return DEFAULT_ICON return DEFAULT_ICON
@ -240,15 +238,15 @@ class NMBSLiveBoard(SensorEntity):
if self._state is None or not self._attrs: if self._state is None or not self._attrs:
return None return None
delay = get_delay_in_minutes(self._attrs["delay"]) delay = get_delay_in_minutes(self._attrs.delay)
departure = get_time_until(self._attrs["time"]) departure = get_time_until(self._attrs.time)
attrs = { attrs = {
"departure": f"In {departure} minutes", "departure": f"In {departure} minutes",
"departure_minutes": departure, "departure_minutes": departure,
"extra_train": int(self._attrs["isExtra"]) > 0, "extra_train": self._attrs.is_extra,
"vehicle_id": self._attrs["vehicle"], "vehicle_id": self._attrs.vehicle,
"monitored_station": self._station["standardname"], "monitored_station": self._station.standard_name,
} }
if delay > 0: if delay > 0:
@ -257,28 +255,26 @@ class NMBSLiveBoard(SensorEntity):
return attrs return attrs
def update(self) -> None: async def async_update(self, **kwargs: Any) -> None:
"""Set the state equal to the next departure.""" """Set the state equal to the next departure."""
liveboard = self._api_client.get_liveboard(self._station["id"]) liveboard = await self._api_client.get_liveboard(self._station.id)
if liveboard == API_FAILURE: if liveboard is None:
_LOGGER.warning("API failed in NMBSLiveBoard") _LOGGER.warning("API failed in NMBSLiveBoard")
return return
if not (departures := liveboard.get("departures")): if not (departures := liveboard.departures):
_LOGGER.warning("API returned invalid departures: %r", liveboard) _LOGGER.warning("API returned invalid departures: %r", liveboard)
return return
_LOGGER.debug("API returned departures: %r", departures) _LOGGER.debug("API returned departures: %r", departures)
if departures["number"] == "0": if len(departures) == 0:
# No trains are scheduled # No trains are scheduled
return return
next_departure = departures["departure"][0] next_departure = departures[0]
self._attrs = next_departure self._attrs = next_departure
self._state = ( self._state = f"Track {next_departure.platform} - {next_departure.station}"
f"Track {next_departure['platform']} - {next_departure['station']}"
)
class NMBSSensor(SensorEntity): class NMBSSensor(SensorEntity):
@ -292,8 +288,8 @@ class NMBSSensor(SensorEntity):
api_client: iRail, api_client: iRail,
name: str, name: str,
show_on_map: bool, show_on_map: bool,
station_from: dict[str, Any], station_from: StationDetails,
station_to: dict[str, Any], station_to: StationDetails,
excl_vias: bool, excl_vias: bool,
) -> None: ) -> None:
"""Initialize the NMBS connection sensor.""" """Initialize the NMBS connection sensor."""
@ -304,13 +300,13 @@ class NMBSSensor(SensorEntity):
self._station_to = station_to self._station_to = station_to
self._excl_vias = excl_vias self._excl_vias = excl_vias
self._attrs: dict[str, Any] | None = {} self._attrs: ConnectionDetails | None = None
self._state = None self._state = None
@property @property
def unique_id(self) -> str: def unique_id(self) -> str:
"""Return the unique ID.""" """Return the unique ID."""
unique_id = f"{self._station_from['id']}_{self._station_to['id']}" unique_id = f"{self._station_from.id}_{self._station_to.id}"
vias = "_excl_vias" if self._excl_vias else "" vias = "_excl_vias" if self._excl_vias else ""
return f"nmbs_connection_{unique_id}{vias}" return f"nmbs_connection_{unique_id}{vias}"
@ -319,14 +315,14 @@ class NMBSSensor(SensorEntity):
def name(self) -> str: def name(self) -> str:
"""Return the name of the sensor.""" """Return the name of the sensor."""
if self._name is None: if self._name is None:
return f"Train from {self._station_from['standardname']} to {self._station_to['standardname']}" return f"Train from {self._station_from.standard_name} to {self._station_to.standard_name}"
return self._name return self._name
@property @property
def icon(self) -> str: def icon(self) -> str:
"""Return the sensor default icon or an alert icon if any delay.""" """Return the sensor default icon or an alert icon if any delay."""
if self._attrs: if self._attrs:
delay = get_delay_in_minutes(self._attrs["departure"]["delay"]) delay = get_delay_in_minutes(self._attrs.departure.delay)
if delay > 0: if delay > 0:
return "mdi:alert-octagon" return "mdi:alert-octagon"
@ -338,19 +334,19 @@ class NMBSSensor(SensorEntity):
if self._state is None or not self._attrs: if self._state is None or not self._attrs:
return None return None
delay = get_delay_in_minutes(self._attrs["departure"]["delay"]) delay = get_delay_in_minutes(self._attrs.departure.delay)
departure = get_time_until(self._attrs["departure"]["time"]) departure = get_time_until(self._attrs.departure.time)
canceled = int(self._attrs["departure"]["canceled"]) canceled = self._attrs.departure.canceled
attrs = { attrs = {
"destination": self._attrs["departure"]["station"], "destination": self._attrs.departure.station,
"direction": self._attrs["departure"]["direction"]["name"], "direction": self._attrs.departure.direction.name,
"platform_arriving": self._attrs["arrival"]["platform"], "platform_arriving": self._attrs.arrival.platform,
"platform_departing": self._attrs["departure"]["platform"], "platform_departing": self._attrs.departure.platform,
"vehicle_id": self._attrs["departure"]["vehicle"], "vehicle_id": self._attrs.departure.vehicle,
} }
if canceled != 1: if not canceled:
attrs["departure"] = f"In {departure} minutes" attrs["departure"] = f"In {departure} minutes"
attrs["departure_minutes"] = departure attrs["departure_minutes"] = departure
attrs["canceled"] = False attrs["canceled"] = False
@ -364,14 +360,14 @@ class NMBSSensor(SensorEntity):
attrs[ATTR_LONGITUDE] = self.station_coordinates[1] attrs[ATTR_LONGITUDE] = self.station_coordinates[1]
if self.is_via_connection and not self._excl_vias: if self.is_via_connection and not self._excl_vias:
via = self._attrs["vias"]["via"][0] via = self._attrs.vias.via[0]
attrs["via"] = via["station"] attrs["via"] = via.station
attrs["via_arrival_platform"] = via["arrival"]["platform"] attrs["via_arrival_platform"] = via.arrival.platform
attrs["via_transfer_platform"] = via["departure"]["platform"] attrs["via_transfer_platform"] = via.departure.platform
attrs["via_transfer_time"] = get_delay_in_minutes( attrs["via_transfer_time"] = get_delay_in_minutes(
via["timebetween"] via.timebetween
) + get_delay_in_minutes(via["departure"]["delay"]) ) + get_delay_in_minutes(via.departure.delay)
if delay > 0: if delay > 0:
attrs["delay"] = f"{delay} minutes" attrs["delay"] = f"{delay} minutes"
@ -390,8 +386,8 @@ class NMBSSensor(SensorEntity):
if self._state is None or not self._attrs: if self._state is None or not self._attrs:
return [] return []
latitude = float(self._attrs["departure"]["stationinfo"]["locationY"]) latitude = float(self._attrs.departure.station_info.latitude)
longitude = float(self._attrs["departure"]["stationinfo"]["locationX"]) longitude = float(self._attrs.departure.station_info.longitude)
return [latitude, longitude] return [latitude, longitude]
@property @property
@ -400,24 +396,24 @@ class NMBSSensor(SensorEntity):
if not self._attrs: if not self._attrs:
return False return False
return "vias" in self._attrs and int(self._attrs["vias"]["number"]) > 0 return self._attrs.vias is not None and len(self._attrs.vias) > 0
def update(self) -> None: async def async_update(self, **kwargs: Any) -> None:
"""Set the state to the duration of a connection.""" """Set the state to the duration of a connection."""
connections = self._api_client.get_connections( connections = await self._api_client.get_connections(
self._station_from["id"], self._station_to["id"] self._station_from.id, self._station_to.id
) )
if connections == API_FAILURE: if connections is None:
_LOGGER.warning("API failed in NMBSSensor") _LOGGER.warning("API failed in NMBSSensor")
return return
if not (connection := connections.get("connection")): if not (connection := connections.connections):
_LOGGER.warning("API returned invalid connection: %r", connections) _LOGGER.warning("API returned invalid connection: %r", connections)
return return
_LOGGER.debug("API returned connection: %r", connection) _LOGGER.debug("API returned connection: %r", connection)
if int(connection[0]["departure"]["left"]) > 0: if connection[0].departure.left:
next_connection = connection[1] next_connection = connection[1]
else: else:
next_connection = connection[0] next_connection = connection[0]
@ -431,9 +427,9 @@ class NMBSSensor(SensorEntity):
return return
duration = get_ride_duration( duration = get_ride_duration(
next_connection["departure"]["time"], next_connection.departure.time,
next_connection["arrival"]["time"], next_connection.arrival.time,
next_connection["departure"]["delay"], next_connection.departure.delay,
) )
self._state = duration self._state = duration

2
requirements_all.txt generated
View File

@ -2244,7 +2244,7 @@ pyqvrpro==0.52
pyqwikswitch==0.93 pyqwikswitch==0.93
# homeassistant.components.nmbs # homeassistant.components.nmbs
pyrail==0.0.3 pyrail==0.4.1
# homeassistant.components.rainbird # homeassistant.components.rainbird
pyrainbird==6.0.1 pyrainbird==6.0.1

View File

@ -1834,7 +1834,7 @@ pyps4-2ndscreen==1.3.1
pyqwikswitch==0.93 pyqwikswitch==0.93
# homeassistant.components.nmbs # homeassistant.components.nmbs
pyrail==0.0.3 pyrail==0.4.1
# homeassistant.components.rainbird # homeassistant.components.rainbird
pyrainbird==6.0.1 pyrainbird==6.0.1

View File

@ -1,20 +1 @@
"""Tests for the NMBS integration.""" """Tests for the NMBS integration."""
import json
from typing import Any
from tests.common import load_fixture
def mock_api_unavailable() -> dict[str, Any]:
"""Mock for unavailable api."""
return -1
def mock_station_response() -> dict[str, Any]:
"""Mock for valid station response."""
dummy_stations_response: dict[str, Any] = json.loads(
load_fixture("stations.json", "nmbs")
)
return dummy_stations_response

View File

@ -3,6 +3,7 @@
from collections.abc import Generator from collections.abc import Generator
from unittest.mock import AsyncMock, patch from unittest.mock import AsyncMock, patch
from pyrail.models import StationsApiResponse
import pytest import pytest
from homeassistant.components.nmbs.const import ( from homeassistant.components.nmbs.const import (
@ -38,8 +39,8 @@ def mock_nmbs_client() -> Generator[AsyncMock]:
), ),
): ):
client = mock_client.return_value client = mock_client.return_value
client.get_stations.return_value = load_json_object_fixture( client.get_stations.return_value = StationsApiResponse.from_dict(
"stations.json", DOMAIN load_json_object_fixture("stations.json", DOMAIN)
) )
yield client yield client

View File

@ -142,7 +142,7 @@ async def test_unavailable_api(
hass: HomeAssistant, mock_nmbs_client: AsyncMock hass: HomeAssistant, mock_nmbs_client: AsyncMock
) -> None: ) -> None:
"""Test starting a flow by user and api is unavailable.""" """Test starting a flow by user and api is unavailable."""
mock_nmbs_client.get_stations.return_value = -1 mock_nmbs_client.get_stations.return_value = None
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": config_entries.SOURCE_USER}, context={"source": config_entries.SOURCE_USER},
@ -203,7 +203,7 @@ async def test_unavailable_api_import(
hass: HomeAssistant, mock_nmbs_client: AsyncMock hass: HomeAssistant, mock_nmbs_client: AsyncMock
) -> None: ) -> None:
"""Test starting a flow by import and api is unavailable.""" """Test starting a flow by import and api is unavailable."""
mock_nmbs_client.get_stations.return_value = -1 mock_nmbs_client.get_stations.return_value = None
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, DOMAIN,
context={"source": SOURCE_IMPORT}, context={"source": SOURCE_IMPORT},