Add typing/async to NMBS (#139002)

* Add typing/async to NMBS

* Fix tests

* Boolean fields

* Update homeassistant/components/nmbs/sensor.py

Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>

---------

Co-authored-by: Shay Levy <levyshay1@gmail.com>
Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>
This commit is contained in:
Simon Lamon 2025-03-02 17:36:37 +01:00 committed by GitHub
parent de4540c68e
commit 40099547ef
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 101 additions and 126 deletions

View File

@ -8,6 +8,7 @@ from homeassistant.config_entries import ConfigEntry
from homeassistant.const import Platform
from homeassistant.core import HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType
from .const import DOMAIN
@ -22,13 +23,13 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
"""Set up the NMBS component."""
api_client = iRail()
api_client = iRail(session=async_get_clientsession(hass))
hass.data.setdefault(DOMAIN, {})
station_response = await hass.async_add_executor_job(api_client.get_stations)
if station_response == -1:
station_response = await api_client.get_stations()
if station_response is None:
return False
hass.data[DOMAIN] = station_response["station"]
hass.data[DOMAIN] = station_response.stations
return True

View File

@ -3,11 +3,13 @@
from typing import Any
from pyrail import iRail
from pyrail.models import StationDetails
import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.const import Platform
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.selector import (
BooleanSelector,
SelectOptionDict,
@ -31,17 +33,15 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
def __init__(self) -> None:
"""Initialize."""
self.api_client = iRail()
self.stations: list[dict[str, Any]] = []
self.stations: list[StationDetails] = []
async def _fetch_stations(self) -> list[dict[str, Any]]:
async def _fetch_stations(self) -> list[StationDetails]:
"""Fetch the stations."""
stations_response = await self.hass.async_add_executor_job(
self.api_client.get_stations
)
if stations_response == -1:
api_client = iRail(session=async_get_clientsession(self.hass))
stations_response = await api_client.get_stations()
if stations_response is None:
raise CannotConnect("The API is currently unavailable.")
return stations_response["station"]
return stations_response.stations
async def _fetch_stations_choices(self) -> list[SelectOptionDict]:
"""Fetch the stations options."""
@ -50,7 +50,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
self.stations = await self._fetch_stations()
return [
SelectOptionDict(value=station["id"], label=station["standardname"])
SelectOptionDict(value=station.id, label=station.standard_name)
for station in self.stations
]
@ -72,12 +72,12 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
[station_from] = [
station
for station in self.stations
if station["id"] == user_input[CONF_STATION_FROM]
if station.id == user_input[CONF_STATION_FROM]
]
[station_to] = [
station
for station in self.stations
if station["id"] == user_input[CONF_STATION_TO]
if station.id == user_input[CONF_STATION_TO]
]
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else ""
await self.async_set_unique_id(
@ -85,7 +85,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
)
self._abort_if_unique_id_configured()
config_entry_name = f"Train from {station_from['standardname']} to {station_to['standardname']}"
config_entry_name = f"Train from {station_from.standard_name} to {station_to.standard_name}"
return self.async_create_entry(
title=config_entry_name,
data=user_input,
@ -127,18 +127,18 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
station_live = None
for station in self.stations:
if user_input[CONF_STATION_FROM] in (
station["standardname"],
station["name"],
station.standard_name,
station.name,
):
station_from = station
if user_input[CONF_STATION_TO] in (
station["standardname"],
station["name"],
station.standard_name,
station.name,
):
station_to = station
if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in (
station["standardname"],
station["name"],
station.standard_name,
station.name,
):
station_live = station
@ -148,29 +148,29 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
return self.async_abort(reason="same_station")
# config flow uses id and not the standard name
user_input[CONF_STATION_FROM] = station_from["id"]
user_input[CONF_STATION_TO] = station_to["id"]
user_input[CONF_STATION_FROM] = station_from.id
user_input[CONF_STATION_TO] = station_to.id
if station_live:
user_input[CONF_STATION_LIVE] = station_live["id"]
user_input[CONF_STATION_LIVE] = station_live.id
entity_registry = er.async_get(self.hass)
prefix = "live"
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else ""
if entity_id := entity_registry.async_get_entity_id(
Platform.SENSOR,
DOMAIN,
f"{prefix}_{station_live['standardname']}_{station_from['standardname']}_{station_to['standardname']}",
f"{prefix}_{station_live.standard_name}_{station_from.standard_name}_{station_to.standard_name}",
):
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
entity_registry.async_update_entity(
entity_id, new_unique_id=new_unique_id
)
if entity_id := entity_registry.async_get_entity_id(
Platform.SENSOR,
DOMAIN,
f"{prefix}_{station_live['name']}_{station_from['name']}_{station_to['name']}",
f"{prefix}_{station_live.name}_{station_from.name}_{station_to.name}",
):
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
entity_registry.async_update_entity(
entity_id, new_unique_id=new_unique_id
)

View File

@ -19,11 +19,7 @@ CONF_SHOW_ON_MAP = "show_on_map"
def find_station_by_name(hass: HomeAssistant, station_name: str):
"""Find given station_name in the station list."""
return next(
(
s
for s in hass.data[DOMAIN]
if station_name in (s["standardname"], s["name"])
),
(s for s in hass.data[DOMAIN] if station_name in (s.standard_name, s.name)),
None,
)
@ -31,6 +27,6 @@ def find_station_by_name(hass: HomeAssistant, station_name: str):
def find_station(hass: HomeAssistant, station_name: str):
"""Find given station_id in the station list."""
return next(
(s for s in hass.data[DOMAIN] if station_name in s["id"]),
(s for s in hass.data[DOMAIN] if station_name in s.id),
None,
)

View File

@ -7,5 +7,5 @@
"iot_class": "cloud_polling",
"loggers": ["pyrail"],
"quality_scale": "legacy",
"requirements": ["pyrail==0.0.3"]
"requirements": ["pyrail==0.4.1"]
}

View File

@ -2,10 +2,12 @@
from __future__ import annotations
from datetime import datetime
import logging
from typing import Any
from pyrail import iRail
from pyrail.models import ConnectionDetails, LiveboardDeparture, StationDetails
import voluptuous as vol
from homeassistant.components.sensor import (
@ -23,6 +25,7 @@ from homeassistant.const import (
)
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
from homeassistant.helpers import config_validation as cv
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.entity_platform import (
AddConfigEntryEntitiesCallback,
AddEntitiesCallback,
@ -44,8 +47,6 @@ from .const import ( # noqa: F401
_LOGGER = logging.getLogger(__name__)
API_FAILURE = -1
DEFAULT_NAME = "NMBS"
DEFAULT_ICON = "mdi:train"
@ -63,12 +64,12 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend(
)
def get_time_until(departure_time=None):
def get_time_until(departure_time: datetime | None = None):
"""Calculate the time between now and a train's departure time."""
if departure_time is None:
return 0
delta = dt_util.utc_from_timestamp(int(departure_time)) - dt_util.now()
delta = dt_util.as_utc(departure_time) - dt_util.utcnow()
return round(delta.total_seconds() / 60)
@ -77,11 +78,9 @@ def get_delay_in_minutes(delay=0):
return round(int(delay) / 60)
def get_ride_duration(departure_time, arrival_time, delay=0):
def get_ride_duration(departure_time: datetime, arrival_time: datetime, delay=0):
"""Calculate the total travel time in minutes."""
duration = dt_util.utc_from_timestamp(
int(arrival_time)
) - dt_util.utc_from_timestamp(int(departure_time))
duration = arrival_time - departure_time
duration_time = int(round(duration.total_seconds() / 60))
return duration_time + get_delay_in_minutes(delay)
@ -157,7 +156,7 @@ async def async_setup_entry(
async_add_entities: AddConfigEntryEntitiesCallback,
) -> None:
"""Set up NMBS sensor entities based on a config entry."""
api_client = iRail()
api_client = iRail(session=async_get_clientsession(hass))
name = config_entry.data.get(CONF_NAME, None)
show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False)
@ -189,9 +188,9 @@ class NMBSLiveBoard(SensorEntity):
def __init__(
self,
api_client: iRail,
live_station: dict[str, Any],
station_from: dict[str, Any],
station_to: dict[str, Any],
live_station: StationDetails,
station_from: StationDetails,
station_to: StationDetails,
excl_vias: bool,
) -> None:
"""Initialize the sensor for getting liveboard data."""
@ -201,7 +200,8 @@ class NMBSLiveBoard(SensorEntity):
self._station_to = station_to
self._excl_vias = excl_vias
self._attrs: dict[str, Any] | None = {}
self._attrs: LiveboardDeparture | None = None
self._state: str | None = None
self.entity_registry_enabled_default = False
@ -209,22 +209,20 @@ class NMBSLiveBoard(SensorEntity):
@property
def name(self) -> str:
"""Return the sensor default name."""
return f"Trains in {self._station['standardname']}"
return f"Trains in {self._station.standard_name}"
@property
def unique_id(self) -> str:
"""Return the unique ID."""
unique_id = (
f"{self._station['id']}_{self._station_from['id']}_{self._station_to['id']}"
)
unique_id = f"{self._station.id}_{self._station_from.id}_{self._station_to.id}"
vias = "_excl_vias" if self._excl_vias else ""
return f"nmbs_live_{unique_id}{vias}"
@property
def icon(self) -> str:
"""Return the default icon or an alert icon if delays."""
if self._attrs and int(self._attrs["delay"]) > 0:
if self._attrs and int(self._attrs.delay) > 0:
return DEFAULT_ICON_ALERT
return DEFAULT_ICON
@ -240,15 +238,15 @@ class NMBSLiveBoard(SensorEntity):
if self._state is None or not self._attrs:
return None
delay = get_delay_in_minutes(self._attrs["delay"])
departure = get_time_until(self._attrs["time"])
delay = get_delay_in_minutes(self._attrs.delay)
departure = get_time_until(self._attrs.time)
attrs = {
"departure": f"In {departure} minutes",
"departure_minutes": departure,
"extra_train": int(self._attrs["isExtra"]) > 0,
"vehicle_id": self._attrs["vehicle"],
"monitored_station": self._station["standardname"],
"extra_train": self._attrs.is_extra,
"vehicle_id": self._attrs.vehicle,
"monitored_station": self._station.standard_name,
}
if delay > 0:
@ -257,28 +255,26 @@ class NMBSLiveBoard(SensorEntity):
return attrs
def update(self) -> None:
async def async_update(self, **kwargs: Any) -> None:
"""Set the state equal to the next departure."""
liveboard = self._api_client.get_liveboard(self._station["id"])
liveboard = await self._api_client.get_liveboard(self._station.id)
if liveboard == API_FAILURE:
if liveboard is None:
_LOGGER.warning("API failed in NMBSLiveBoard")
return
if not (departures := liveboard.get("departures")):
if not (departures := liveboard.departures):
_LOGGER.warning("API returned invalid departures: %r", liveboard)
return
_LOGGER.debug("API returned departures: %r", departures)
if departures["number"] == "0":
if len(departures) == 0:
# No trains are scheduled
return
next_departure = departures["departure"][0]
next_departure = departures[0]
self._attrs = next_departure
self._state = (
f"Track {next_departure['platform']} - {next_departure['station']}"
)
self._state = f"Track {next_departure.platform} - {next_departure.station}"
class NMBSSensor(SensorEntity):
@ -292,8 +288,8 @@ class NMBSSensor(SensorEntity):
api_client: iRail,
name: str,
show_on_map: bool,
station_from: dict[str, Any],
station_to: dict[str, Any],
station_from: StationDetails,
station_to: StationDetails,
excl_vias: bool,
) -> None:
"""Initialize the NMBS connection sensor."""
@ -304,13 +300,13 @@ class NMBSSensor(SensorEntity):
self._station_to = station_to
self._excl_vias = excl_vias
self._attrs: dict[str, Any] | None = {}
self._attrs: ConnectionDetails | None = None
self._state = None
@property
def unique_id(self) -> str:
"""Return the unique ID."""
unique_id = f"{self._station_from['id']}_{self._station_to['id']}"
unique_id = f"{self._station_from.id}_{self._station_to.id}"
vias = "_excl_vias" if self._excl_vias else ""
return f"nmbs_connection_{unique_id}{vias}"
@ -319,14 +315,14 @@ class NMBSSensor(SensorEntity):
def name(self) -> str:
"""Return the name of the sensor."""
if self._name is None:
return f"Train from {self._station_from['standardname']} to {self._station_to['standardname']}"
return f"Train from {self._station_from.standard_name} to {self._station_to.standard_name}"
return self._name
@property
def icon(self) -> str:
"""Return the sensor default icon or an alert icon if any delay."""
if self._attrs:
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
delay = get_delay_in_minutes(self._attrs.departure.delay)
if delay > 0:
return "mdi:alert-octagon"
@ -338,19 +334,19 @@ class NMBSSensor(SensorEntity):
if self._state is None or not self._attrs:
return None
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
departure = get_time_until(self._attrs["departure"]["time"])
canceled = int(self._attrs["departure"]["canceled"])
delay = get_delay_in_minutes(self._attrs.departure.delay)
departure = get_time_until(self._attrs.departure.time)
canceled = self._attrs.departure.canceled
attrs = {
"destination": self._attrs["departure"]["station"],
"direction": self._attrs["departure"]["direction"]["name"],
"platform_arriving": self._attrs["arrival"]["platform"],
"platform_departing": self._attrs["departure"]["platform"],
"vehicle_id": self._attrs["departure"]["vehicle"],
"destination": self._attrs.departure.station,
"direction": self._attrs.departure.direction.name,
"platform_arriving": self._attrs.arrival.platform,
"platform_departing": self._attrs.departure.platform,
"vehicle_id": self._attrs.departure.vehicle,
}
if canceled != 1:
if not canceled:
attrs["departure"] = f"In {departure} minutes"
attrs["departure_minutes"] = departure
attrs["canceled"] = False
@ -364,14 +360,14 @@ class NMBSSensor(SensorEntity):
attrs[ATTR_LONGITUDE] = self.station_coordinates[1]
if self.is_via_connection and not self._excl_vias:
via = self._attrs["vias"]["via"][0]
via = self._attrs.vias.via[0]
attrs["via"] = via["station"]
attrs["via_arrival_platform"] = via["arrival"]["platform"]
attrs["via_transfer_platform"] = via["departure"]["platform"]
attrs["via"] = via.station
attrs["via_arrival_platform"] = via.arrival.platform
attrs["via_transfer_platform"] = via.departure.platform
attrs["via_transfer_time"] = get_delay_in_minutes(
via["timebetween"]
) + get_delay_in_minutes(via["departure"]["delay"])
via.timebetween
) + get_delay_in_minutes(via.departure.delay)
if delay > 0:
attrs["delay"] = f"{delay} minutes"
@ -390,8 +386,8 @@ class NMBSSensor(SensorEntity):
if self._state is None or not self._attrs:
return []
latitude = float(self._attrs["departure"]["stationinfo"]["locationY"])
longitude = float(self._attrs["departure"]["stationinfo"]["locationX"])
latitude = float(self._attrs.departure.station_info.latitude)
longitude = float(self._attrs.departure.station_info.longitude)
return [latitude, longitude]
@property
@ -400,24 +396,24 @@ class NMBSSensor(SensorEntity):
if not self._attrs:
return False
return "vias" in self._attrs and int(self._attrs["vias"]["number"]) > 0
return self._attrs.vias is not None and len(self._attrs.vias) > 0
def update(self) -> None:
async def async_update(self, **kwargs: Any) -> None:
"""Set the state to the duration of a connection."""
connections = self._api_client.get_connections(
self._station_from["id"], self._station_to["id"]
connections = await self._api_client.get_connections(
self._station_from.id, self._station_to.id
)
if connections == API_FAILURE:
if connections is None:
_LOGGER.warning("API failed in NMBSSensor")
return
if not (connection := connections.get("connection")):
if not (connection := connections.connections):
_LOGGER.warning("API returned invalid connection: %r", connections)
return
_LOGGER.debug("API returned connection: %r", connection)
if int(connection[0]["departure"]["left"]) > 0:
if connection[0].departure.left:
next_connection = connection[1]
else:
next_connection = connection[0]
@ -431,9 +427,9 @@ class NMBSSensor(SensorEntity):
return
duration = get_ride_duration(
next_connection["departure"]["time"],
next_connection["arrival"]["time"],
next_connection["departure"]["delay"],
next_connection.departure.time,
next_connection.arrival.time,
next_connection.departure.delay,
)
self._state = duration

2
requirements_all.txt generated
View File

@ -2244,7 +2244,7 @@ pyqvrpro==0.52
pyqwikswitch==0.93
# homeassistant.components.nmbs
pyrail==0.0.3
pyrail==0.4.1
# homeassistant.components.rainbird
pyrainbird==6.0.1

View File

@ -1834,7 +1834,7 @@ pyps4-2ndscreen==1.3.1
pyqwikswitch==0.93
# homeassistant.components.nmbs
pyrail==0.0.3
pyrail==0.4.1
# homeassistant.components.rainbird
pyrainbird==6.0.1

View File

@ -1,20 +1 @@
"""Tests for the NMBS integration."""
import json
from typing import Any
from tests.common import load_fixture
def mock_api_unavailable() -> dict[str, Any]:
"""Mock for unavailable api."""
return -1
def mock_station_response() -> dict[str, Any]:
"""Mock for valid station response."""
dummy_stations_response: dict[str, Any] = json.loads(
load_fixture("stations.json", "nmbs")
)
return dummy_stations_response

View File

@ -3,6 +3,7 @@
from collections.abc import Generator
from unittest.mock import AsyncMock, patch
from pyrail.models import StationsApiResponse
import pytest
from homeassistant.components.nmbs.const import (
@ -38,8 +39,8 @@ def mock_nmbs_client() -> Generator[AsyncMock]:
),
):
client = mock_client.return_value
client.get_stations.return_value = load_json_object_fixture(
"stations.json", DOMAIN
client.get_stations.return_value = StationsApiResponse.from_dict(
load_json_object_fixture("stations.json", DOMAIN)
)
yield client

View File

@ -142,7 +142,7 @@ async def test_unavailable_api(
hass: HomeAssistant, mock_nmbs_client: AsyncMock
) -> None:
"""Test starting a flow by user and api is unavailable."""
mock_nmbs_client.get_stations.return_value = -1
mock_nmbs_client.get_stations.return_value = None
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": config_entries.SOURCE_USER},
@ -203,7 +203,7 @@ async def test_unavailable_api_import(
hass: HomeAssistant, mock_nmbs_client: AsyncMock
) -> None:
"""Test starting a flow by import and api is unavailable."""
mock_nmbs_client.get_stations.return_value = -1
mock_nmbs_client.get_stations.return_value = None
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={"source": SOURCE_IMPORT},