mirror of
https://github.com/home-assistant/core.git
synced 2025-04-19 14:57:52 +00:00
Add typing/async to NMBS (#139002)
* Add typing/async to NMBS * Fix tests * Boolean fields * Update homeassistant/components/nmbs/sensor.py Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com> --------- Co-authored-by: Shay Levy <levyshay1@gmail.com> Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>
This commit is contained in:
parent
de4540c68e
commit
40099547ef
@ -8,6 +8,7 @@ from homeassistant.config_entries import ConfigEntry
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.core import HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.typing import ConfigType
|
||||
|
||||
from .const import DOMAIN
|
||||
@ -22,13 +23,13 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||
"""Set up the NMBS component."""
|
||||
|
||||
api_client = iRail()
|
||||
api_client = iRail(session=async_get_clientsession(hass))
|
||||
|
||||
hass.data.setdefault(DOMAIN, {})
|
||||
station_response = await hass.async_add_executor_job(api_client.get_stations)
|
||||
if station_response == -1:
|
||||
station_response = await api_client.get_stations()
|
||||
if station_response is None:
|
||||
return False
|
||||
hass.data[DOMAIN] = station_response["station"]
|
||||
hass.data[DOMAIN] = station_response.stations
|
||||
|
||||
return True
|
||||
|
||||
|
@ -3,11 +3,13 @@
|
||||
from typing import Any
|
||||
|
||||
from pyrail import iRail
|
||||
from pyrail.models import StationDetails
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||
from homeassistant.const import Platform
|
||||
from homeassistant.helpers import entity_registry as er
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.selector import (
|
||||
BooleanSelector,
|
||||
SelectOptionDict,
|
||||
@ -31,17 +33,15 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize."""
|
||||
self.api_client = iRail()
|
||||
self.stations: list[dict[str, Any]] = []
|
||||
self.stations: list[StationDetails] = []
|
||||
|
||||
async def _fetch_stations(self) -> list[dict[str, Any]]:
|
||||
async def _fetch_stations(self) -> list[StationDetails]:
|
||||
"""Fetch the stations."""
|
||||
stations_response = await self.hass.async_add_executor_job(
|
||||
self.api_client.get_stations
|
||||
)
|
||||
if stations_response == -1:
|
||||
api_client = iRail(session=async_get_clientsession(self.hass))
|
||||
stations_response = await api_client.get_stations()
|
||||
if stations_response is None:
|
||||
raise CannotConnect("The API is currently unavailable.")
|
||||
return stations_response["station"]
|
||||
return stations_response.stations
|
||||
|
||||
async def _fetch_stations_choices(self) -> list[SelectOptionDict]:
|
||||
"""Fetch the stations options."""
|
||||
@ -50,7 +50,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
self.stations = await self._fetch_stations()
|
||||
|
||||
return [
|
||||
SelectOptionDict(value=station["id"], label=station["standardname"])
|
||||
SelectOptionDict(value=station.id, label=station.standard_name)
|
||||
for station in self.stations
|
||||
]
|
||||
|
||||
@ -72,12 +72,12 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
[station_from] = [
|
||||
station
|
||||
for station in self.stations
|
||||
if station["id"] == user_input[CONF_STATION_FROM]
|
||||
if station.id == user_input[CONF_STATION_FROM]
|
||||
]
|
||||
[station_to] = [
|
||||
station
|
||||
for station in self.stations
|
||||
if station["id"] == user_input[CONF_STATION_TO]
|
||||
if station.id == user_input[CONF_STATION_TO]
|
||||
]
|
||||
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else ""
|
||||
await self.async_set_unique_id(
|
||||
@ -85,7 +85,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
)
|
||||
self._abort_if_unique_id_configured()
|
||||
|
||||
config_entry_name = f"Train from {station_from['standardname']} to {station_to['standardname']}"
|
||||
config_entry_name = f"Train from {station_from.standard_name} to {station_to.standard_name}"
|
||||
return self.async_create_entry(
|
||||
title=config_entry_name,
|
||||
data=user_input,
|
||||
@ -127,18 +127,18 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
station_live = None
|
||||
for station in self.stations:
|
||||
if user_input[CONF_STATION_FROM] in (
|
||||
station["standardname"],
|
||||
station["name"],
|
||||
station.standard_name,
|
||||
station.name,
|
||||
):
|
||||
station_from = station
|
||||
if user_input[CONF_STATION_TO] in (
|
||||
station["standardname"],
|
||||
station["name"],
|
||||
station.standard_name,
|
||||
station.name,
|
||||
):
|
||||
station_to = station
|
||||
if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in (
|
||||
station["standardname"],
|
||||
station["name"],
|
||||
station.standard_name,
|
||||
station.name,
|
||||
):
|
||||
station_live = station
|
||||
|
||||
@ -148,29 +148,29 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
||||
return self.async_abort(reason="same_station")
|
||||
|
||||
# config flow uses id and not the standard name
|
||||
user_input[CONF_STATION_FROM] = station_from["id"]
|
||||
user_input[CONF_STATION_TO] = station_to["id"]
|
||||
user_input[CONF_STATION_FROM] = station_from.id
|
||||
user_input[CONF_STATION_TO] = station_to.id
|
||||
|
||||
if station_live:
|
||||
user_input[CONF_STATION_LIVE] = station_live["id"]
|
||||
user_input[CONF_STATION_LIVE] = station_live.id
|
||||
entity_registry = er.async_get(self.hass)
|
||||
prefix = "live"
|
||||
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else ""
|
||||
if entity_id := entity_registry.async_get_entity_id(
|
||||
Platform.SENSOR,
|
||||
DOMAIN,
|
||||
f"{prefix}_{station_live['standardname']}_{station_from['standardname']}_{station_to['standardname']}",
|
||||
f"{prefix}_{station_live.standard_name}_{station_from.standard_name}_{station_to.standard_name}",
|
||||
):
|
||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
|
||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
|
||||
entity_registry.async_update_entity(
|
||||
entity_id, new_unique_id=new_unique_id
|
||||
)
|
||||
if entity_id := entity_registry.async_get_entity_id(
|
||||
Platform.SENSOR,
|
||||
DOMAIN,
|
||||
f"{prefix}_{station_live['name']}_{station_from['name']}_{station_to['name']}",
|
||||
f"{prefix}_{station_live.name}_{station_from.name}_{station_to.name}",
|
||||
):
|
||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
|
||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
|
||||
entity_registry.async_update_entity(
|
||||
entity_id, new_unique_id=new_unique_id
|
||||
)
|
||||
|
@ -19,11 +19,7 @@ CONF_SHOW_ON_MAP = "show_on_map"
|
||||
def find_station_by_name(hass: HomeAssistant, station_name: str):
|
||||
"""Find given station_name in the station list."""
|
||||
return next(
|
||||
(
|
||||
s
|
||||
for s in hass.data[DOMAIN]
|
||||
if station_name in (s["standardname"], s["name"])
|
||||
),
|
||||
(s for s in hass.data[DOMAIN] if station_name in (s.standard_name, s.name)),
|
||||
None,
|
||||
)
|
||||
|
||||
@ -31,6 +27,6 @@ def find_station_by_name(hass: HomeAssistant, station_name: str):
|
||||
def find_station(hass: HomeAssistant, station_name: str):
|
||||
"""Find given station_id in the station list."""
|
||||
return next(
|
||||
(s for s in hass.data[DOMAIN] if station_name in s["id"]),
|
||||
(s for s in hass.data[DOMAIN] if station_name in s.id),
|
||||
None,
|
||||
)
|
||||
|
@ -7,5 +7,5 @@
|
||||
"iot_class": "cloud_polling",
|
||||
"loggers": ["pyrail"],
|
||||
"quality_scale": "legacy",
|
||||
"requirements": ["pyrail==0.0.3"]
|
||||
"requirements": ["pyrail==0.4.1"]
|
||||
}
|
||||
|
@ -2,10 +2,12 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from pyrail import iRail
|
||||
from pyrail.models import ConnectionDetails, LiveboardDeparture, StationDetails
|
||||
import voluptuous as vol
|
||||
|
||||
from homeassistant.components.sensor import (
|
||||
@ -23,6 +25,7 @@ from homeassistant.const import (
|
||||
)
|
||||
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
|
||||
from homeassistant.helpers import config_validation as cv
|
||||
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||
from homeassistant.helpers.entity_platform import (
|
||||
AddConfigEntryEntitiesCallback,
|
||||
AddEntitiesCallback,
|
||||
@ -44,8 +47,6 @@ from .const import ( # noqa: F401
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
API_FAILURE = -1
|
||||
|
||||
DEFAULT_NAME = "NMBS"
|
||||
|
||||
DEFAULT_ICON = "mdi:train"
|
||||
@ -63,12 +64,12 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend(
|
||||
)
|
||||
|
||||
|
||||
def get_time_until(departure_time=None):
|
||||
def get_time_until(departure_time: datetime | None = None):
|
||||
"""Calculate the time between now and a train's departure time."""
|
||||
if departure_time is None:
|
||||
return 0
|
||||
|
||||
delta = dt_util.utc_from_timestamp(int(departure_time)) - dt_util.now()
|
||||
delta = dt_util.as_utc(departure_time) - dt_util.utcnow()
|
||||
return round(delta.total_seconds() / 60)
|
||||
|
||||
|
||||
@ -77,11 +78,9 @@ def get_delay_in_minutes(delay=0):
|
||||
return round(int(delay) / 60)
|
||||
|
||||
|
||||
def get_ride_duration(departure_time, arrival_time, delay=0):
|
||||
def get_ride_duration(departure_time: datetime, arrival_time: datetime, delay=0):
|
||||
"""Calculate the total travel time in minutes."""
|
||||
duration = dt_util.utc_from_timestamp(
|
||||
int(arrival_time)
|
||||
) - dt_util.utc_from_timestamp(int(departure_time))
|
||||
duration = arrival_time - departure_time
|
||||
duration_time = int(round(duration.total_seconds() / 60))
|
||||
return duration_time + get_delay_in_minutes(delay)
|
||||
|
||||
@ -157,7 +156,7 @@ async def async_setup_entry(
|
||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||
) -> None:
|
||||
"""Set up NMBS sensor entities based on a config entry."""
|
||||
api_client = iRail()
|
||||
api_client = iRail(session=async_get_clientsession(hass))
|
||||
|
||||
name = config_entry.data.get(CONF_NAME, None)
|
||||
show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False)
|
||||
@ -189,9 +188,9 @@ class NMBSLiveBoard(SensorEntity):
|
||||
def __init__(
|
||||
self,
|
||||
api_client: iRail,
|
||||
live_station: dict[str, Any],
|
||||
station_from: dict[str, Any],
|
||||
station_to: dict[str, Any],
|
||||
live_station: StationDetails,
|
||||
station_from: StationDetails,
|
||||
station_to: StationDetails,
|
||||
excl_vias: bool,
|
||||
) -> None:
|
||||
"""Initialize the sensor for getting liveboard data."""
|
||||
@ -201,7 +200,8 @@ class NMBSLiveBoard(SensorEntity):
|
||||
self._station_to = station_to
|
||||
|
||||
self._excl_vias = excl_vias
|
||||
self._attrs: dict[str, Any] | None = {}
|
||||
self._attrs: LiveboardDeparture | None = None
|
||||
|
||||
self._state: str | None = None
|
||||
|
||||
self.entity_registry_enabled_default = False
|
||||
@ -209,22 +209,20 @@ class NMBSLiveBoard(SensorEntity):
|
||||
@property
|
||||
def name(self) -> str:
|
||||
"""Return the sensor default name."""
|
||||
return f"Trains in {self._station['standardname']}"
|
||||
return f"Trains in {self._station.standard_name}"
|
||||
|
||||
@property
|
||||
def unique_id(self) -> str:
|
||||
"""Return the unique ID."""
|
||||
|
||||
unique_id = (
|
||||
f"{self._station['id']}_{self._station_from['id']}_{self._station_to['id']}"
|
||||
)
|
||||
unique_id = f"{self._station.id}_{self._station_from.id}_{self._station_to.id}"
|
||||
vias = "_excl_vias" if self._excl_vias else ""
|
||||
return f"nmbs_live_{unique_id}{vias}"
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the default icon or an alert icon if delays."""
|
||||
if self._attrs and int(self._attrs["delay"]) > 0:
|
||||
if self._attrs and int(self._attrs.delay) > 0:
|
||||
return DEFAULT_ICON_ALERT
|
||||
|
||||
return DEFAULT_ICON
|
||||
@ -240,15 +238,15 @@ class NMBSLiveBoard(SensorEntity):
|
||||
if self._state is None or not self._attrs:
|
||||
return None
|
||||
|
||||
delay = get_delay_in_minutes(self._attrs["delay"])
|
||||
departure = get_time_until(self._attrs["time"])
|
||||
delay = get_delay_in_minutes(self._attrs.delay)
|
||||
departure = get_time_until(self._attrs.time)
|
||||
|
||||
attrs = {
|
||||
"departure": f"In {departure} minutes",
|
||||
"departure_minutes": departure,
|
||||
"extra_train": int(self._attrs["isExtra"]) > 0,
|
||||
"vehicle_id": self._attrs["vehicle"],
|
||||
"monitored_station": self._station["standardname"],
|
||||
"extra_train": self._attrs.is_extra,
|
||||
"vehicle_id": self._attrs.vehicle,
|
||||
"monitored_station": self._station.standard_name,
|
||||
}
|
||||
|
||||
if delay > 0:
|
||||
@ -257,28 +255,26 @@ class NMBSLiveBoard(SensorEntity):
|
||||
|
||||
return attrs
|
||||
|
||||
def update(self) -> None:
|
||||
async def async_update(self, **kwargs: Any) -> None:
|
||||
"""Set the state equal to the next departure."""
|
||||
liveboard = self._api_client.get_liveboard(self._station["id"])
|
||||
liveboard = await self._api_client.get_liveboard(self._station.id)
|
||||
|
||||
if liveboard == API_FAILURE:
|
||||
if liveboard is None:
|
||||
_LOGGER.warning("API failed in NMBSLiveBoard")
|
||||
return
|
||||
|
||||
if not (departures := liveboard.get("departures")):
|
||||
if not (departures := liveboard.departures):
|
||||
_LOGGER.warning("API returned invalid departures: %r", liveboard)
|
||||
return
|
||||
|
||||
_LOGGER.debug("API returned departures: %r", departures)
|
||||
if departures["number"] == "0":
|
||||
if len(departures) == 0:
|
||||
# No trains are scheduled
|
||||
return
|
||||
next_departure = departures["departure"][0]
|
||||
next_departure = departures[0]
|
||||
|
||||
self._attrs = next_departure
|
||||
self._state = (
|
||||
f"Track {next_departure['platform']} - {next_departure['station']}"
|
||||
)
|
||||
self._state = f"Track {next_departure.platform} - {next_departure.station}"
|
||||
|
||||
|
||||
class NMBSSensor(SensorEntity):
|
||||
@ -292,8 +288,8 @@ class NMBSSensor(SensorEntity):
|
||||
api_client: iRail,
|
||||
name: str,
|
||||
show_on_map: bool,
|
||||
station_from: dict[str, Any],
|
||||
station_to: dict[str, Any],
|
||||
station_from: StationDetails,
|
||||
station_to: StationDetails,
|
||||
excl_vias: bool,
|
||||
) -> None:
|
||||
"""Initialize the NMBS connection sensor."""
|
||||
@ -304,13 +300,13 @@ class NMBSSensor(SensorEntity):
|
||||
self._station_to = station_to
|
||||
self._excl_vias = excl_vias
|
||||
|
||||
self._attrs: dict[str, Any] | None = {}
|
||||
self._attrs: ConnectionDetails | None = None
|
||||
self._state = None
|
||||
|
||||
@property
|
||||
def unique_id(self) -> str:
|
||||
"""Return the unique ID."""
|
||||
unique_id = f"{self._station_from['id']}_{self._station_to['id']}"
|
||||
unique_id = f"{self._station_from.id}_{self._station_to.id}"
|
||||
|
||||
vias = "_excl_vias" if self._excl_vias else ""
|
||||
return f"nmbs_connection_{unique_id}{vias}"
|
||||
@ -319,14 +315,14 @@ class NMBSSensor(SensorEntity):
|
||||
def name(self) -> str:
|
||||
"""Return the name of the sensor."""
|
||||
if self._name is None:
|
||||
return f"Train from {self._station_from['standardname']} to {self._station_to['standardname']}"
|
||||
return f"Train from {self._station_from.standard_name} to {self._station_to.standard_name}"
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def icon(self) -> str:
|
||||
"""Return the sensor default icon or an alert icon if any delay."""
|
||||
if self._attrs:
|
||||
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
|
||||
delay = get_delay_in_minutes(self._attrs.departure.delay)
|
||||
if delay > 0:
|
||||
return "mdi:alert-octagon"
|
||||
|
||||
@ -338,19 +334,19 @@ class NMBSSensor(SensorEntity):
|
||||
if self._state is None or not self._attrs:
|
||||
return None
|
||||
|
||||
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
|
||||
departure = get_time_until(self._attrs["departure"]["time"])
|
||||
canceled = int(self._attrs["departure"]["canceled"])
|
||||
delay = get_delay_in_minutes(self._attrs.departure.delay)
|
||||
departure = get_time_until(self._attrs.departure.time)
|
||||
canceled = self._attrs.departure.canceled
|
||||
|
||||
attrs = {
|
||||
"destination": self._attrs["departure"]["station"],
|
||||
"direction": self._attrs["departure"]["direction"]["name"],
|
||||
"platform_arriving": self._attrs["arrival"]["platform"],
|
||||
"platform_departing": self._attrs["departure"]["platform"],
|
||||
"vehicle_id": self._attrs["departure"]["vehicle"],
|
||||
"destination": self._attrs.departure.station,
|
||||
"direction": self._attrs.departure.direction.name,
|
||||
"platform_arriving": self._attrs.arrival.platform,
|
||||
"platform_departing": self._attrs.departure.platform,
|
||||
"vehicle_id": self._attrs.departure.vehicle,
|
||||
}
|
||||
|
||||
if canceled != 1:
|
||||
if not canceled:
|
||||
attrs["departure"] = f"In {departure} minutes"
|
||||
attrs["departure_minutes"] = departure
|
||||
attrs["canceled"] = False
|
||||
@ -364,14 +360,14 @@ class NMBSSensor(SensorEntity):
|
||||
attrs[ATTR_LONGITUDE] = self.station_coordinates[1]
|
||||
|
||||
if self.is_via_connection and not self._excl_vias:
|
||||
via = self._attrs["vias"]["via"][0]
|
||||
via = self._attrs.vias.via[0]
|
||||
|
||||
attrs["via"] = via["station"]
|
||||
attrs["via_arrival_platform"] = via["arrival"]["platform"]
|
||||
attrs["via_transfer_platform"] = via["departure"]["platform"]
|
||||
attrs["via"] = via.station
|
||||
attrs["via_arrival_platform"] = via.arrival.platform
|
||||
attrs["via_transfer_platform"] = via.departure.platform
|
||||
attrs["via_transfer_time"] = get_delay_in_minutes(
|
||||
via["timebetween"]
|
||||
) + get_delay_in_minutes(via["departure"]["delay"])
|
||||
via.timebetween
|
||||
) + get_delay_in_minutes(via.departure.delay)
|
||||
|
||||
if delay > 0:
|
||||
attrs["delay"] = f"{delay} minutes"
|
||||
@ -390,8 +386,8 @@ class NMBSSensor(SensorEntity):
|
||||
if self._state is None or not self._attrs:
|
||||
return []
|
||||
|
||||
latitude = float(self._attrs["departure"]["stationinfo"]["locationY"])
|
||||
longitude = float(self._attrs["departure"]["stationinfo"]["locationX"])
|
||||
latitude = float(self._attrs.departure.station_info.latitude)
|
||||
longitude = float(self._attrs.departure.station_info.longitude)
|
||||
return [latitude, longitude]
|
||||
|
||||
@property
|
||||
@ -400,24 +396,24 @@ class NMBSSensor(SensorEntity):
|
||||
if not self._attrs:
|
||||
return False
|
||||
|
||||
return "vias" in self._attrs and int(self._attrs["vias"]["number"]) > 0
|
||||
return self._attrs.vias is not None and len(self._attrs.vias) > 0
|
||||
|
||||
def update(self) -> None:
|
||||
async def async_update(self, **kwargs: Any) -> None:
|
||||
"""Set the state to the duration of a connection."""
|
||||
connections = self._api_client.get_connections(
|
||||
self._station_from["id"], self._station_to["id"]
|
||||
connections = await self._api_client.get_connections(
|
||||
self._station_from.id, self._station_to.id
|
||||
)
|
||||
|
||||
if connections == API_FAILURE:
|
||||
if connections is None:
|
||||
_LOGGER.warning("API failed in NMBSSensor")
|
||||
return
|
||||
|
||||
if not (connection := connections.get("connection")):
|
||||
if not (connection := connections.connections):
|
||||
_LOGGER.warning("API returned invalid connection: %r", connections)
|
||||
return
|
||||
|
||||
_LOGGER.debug("API returned connection: %r", connection)
|
||||
if int(connection[0]["departure"]["left"]) > 0:
|
||||
if connection[0].departure.left:
|
||||
next_connection = connection[1]
|
||||
else:
|
||||
next_connection = connection[0]
|
||||
@ -431,9 +427,9 @@ class NMBSSensor(SensorEntity):
|
||||
return
|
||||
|
||||
duration = get_ride_duration(
|
||||
next_connection["departure"]["time"],
|
||||
next_connection["arrival"]["time"],
|
||||
next_connection["departure"]["delay"],
|
||||
next_connection.departure.time,
|
||||
next_connection.arrival.time,
|
||||
next_connection.departure.delay,
|
||||
)
|
||||
|
||||
self._state = duration
|
||||
|
2
requirements_all.txt
generated
2
requirements_all.txt
generated
@ -2244,7 +2244,7 @@ pyqvrpro==0.52
|
||||
pyqwikswitch==0.93
|
||||
|
||||
# homeassistant.components.nmbs
|
||||
pyrail==0.0.3
|
||||
pyrail==0.4.1
|
||||
|
||||
# homeassistant.components.rainbird
|
||||
pyrainbird==6.0.1
|
||||
|
2
requirements_test_all.txt
generated
2
requirements_test_all.txt
generated
@ -1834,7 +1834,7 @@ pyps4-2ndscreen==1.3.1
|
||||
pyqwikswitch==0.93
|
||||
|
||||
# homeassistant.components.nmbs
|
||||
pyrail==0.0.3
|
||||
pyrail==0.4.1
|
||||
|
||||
# homeassistant.components.rainbird
|
||||
pyrainbird==6.0.1
|
||||
|
@ -1,20 +1 @@
|
||||
"""Tests for the NMBS integration."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from tests.common import load_fixture
|
||||
|
||||
|
||||
def mock_api_unavailable() -> dict[str, Any]:
|
||||
"""Mock for unavailable api."""
|
||||
return -1
|
||||
|
||||
|
||||
def mock_station_response() -> dict[str, Any]:
|
||||
"""Mock for valid station response."""
|
||||
dummy_stations_response: dict[str, Any] = json.loads(
|
||||
load_fixture("stations.json", "nmbs")
|
||||
)
|
||||
|
||||
return dummy_stations_response
|
||||
|
@ -3,6 +3,7 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from pyrail.models import StationsApiResponse
|
||||
import pytest
|
||||
|
||||
from homeassistant.components.nmbs.const import (
|
||||
@ -38,8 +39,8 @@ def mock_nmbs_client() -> Generator[AsyncMock]:
|
||||
),
|
||||
):
|
||||
client = mock_client.return_value
|
||||
client.get_stations.return_value = load_json_object_fixture(
|
||||
"stations.json", DOMAIN
|
||||
client.get_stations.return_value = StationsApiResponse.from_dict(
|
||||
load_json_object_fixture("stations.json", DOMAIN)
|
||||
)
|
||||
yield client
|
||||
|
||||
|
@ -142,7 +142,7 @@ async def test_unavailable_api(
|
||||
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test starting a flow by user and api is unavailable."""
|
||||
mock_nmbs_client.get_stations.return_value = -1
|
||||
mock_nmbs_client.get_stations.return_value = None
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": config_entries.SOURCE_USER},
|
||||
@ -203,7 +203,7 @@ async def test_unavailable_api_import(
|
||||
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
||||
) -> None:
|
||||
"""Test starting a flow by import and api is unavailable."""
|
||||
mock_nmbs_client.get_stations.return_value = -1
|
||||
mock_nmbs_client.get_stations.return_value = None
|
||||
result = await hass.config_entries.flow.async_init(
|
||||
DOMAIN,
|
||||
context={"source": SOURCE_IMPORT},
|
||||
|
Loading…
x
Reference in New Issue
Block a user