From 40099547ef6dbd59caf811cafb85df276ba17ccd Mon Sep 17 00:00:00 2001 From: Simon Lamon <32477463+silamon@users.noreply.github.com> Date: Sun, 2 Mar 2025 17:36:37 +0100 Subject: [PATCH] Add typing/async to NMBS (#139002) * Add typing/async to NMBS * Fix tests * Boolean fields * Update homeassistant/components/nmbs/sensor.py Co-authored-by: Jorim Tielemans --------- Co-authored-by: Shay Levy Co-authored-by: Jorim Tielemans --- homeassistant/components/nmbs/__init__.py | 9 +- homeassistant/components/nmbs/config_flow.py | 50 ++++---- homeassistant/components/nmbs/const.py | 8 +- homeassistant/components/nmbs/manifest.json | 2 +- homeassistant/components/nmbs/sensor.py | 126 +++++++++---------- requirements_all.txt | 2 +- requirements_test_all.txt | 2 +- tests/components/nmbs/__init__.py | 19 --- tests/components/nmbs/conftest.py | 5 +- tests/components/nmbs/test_config_flow.py | 4 +- 10 files changed, 101 insertions(+), 126 deletions(-) diff --git a/homeassistant/components/nmbs/__init__.py b/homeassistant/components/nmbs/__init__.py index 7d06baf37b6..4a2783143ca 100644 --- a/homeassistant/components/nmbs/__init__.py +++ b/homeassistant/components/nmbs/__init__.py @@ -8,6 +8,7 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.const import Platform from homeassistant.core import HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.typing import ConfigType from .const import DOMAIN @@ -22,13 +23,13 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN) async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool: """Set up the NMBS component.""" - api_client = iRail() + api_client = iRail(session=async_get_clientsession(hass)) hass.data.setdefault(DOMAIN, {}) - station_response = await hass.async_add_executor_job(api_client.get_stations) - if station_response == -1: + station_response = await api_client.get_stations() + if station_response is None: return False - hass.data[DOMAIN] = station_response["station"] + hass.data[DOMAIN] = station_response.stations return True diff --git a/homeassistant/components/nmbs/config_flow.py b/homeassistant/components/nmbs/config_flow.py index e45b2d9adeb..60ab015e22b 100644 --- a/homeassistant/components/nmbs/config_flow.py +++ b/homeassistant/components/nmbs/config_flow.py @@ -3,11 +3,13 @@ from typing import Any from pyrail import iRail +from pyrail.models import StationDetails import voluptuous as vol from homeassistant.config_entries import ConfigFlow, ConfigFlowResult from homeassistant.const import Platform from homeassistant.helpers import entity_registry as er +from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.selector import ( BooleanSelector, SelectOptionDict, @@ -31,17 +33,15 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): def __init__(self) -> None: """Initialize.""" - self.api_client = iRail() - self.stations: list[dict[str, Any]] = [] + self.stations: list[StationDetails] = [] - async def _fetch_stations(self) -> list[dict[str, Any]]: + async def _fetch_stations(self) -> list[StationDetails]: """Fetch the stations.""" - stations_response = await self.hass.async_add_executor_job( - self.api_client.get_stations - ) - if stations_response == -1: + api_client = iRail(session=async_get_clientsession(self.hass)) + stations_response = await api_client.get_stations() + if stations_response is None: raise CannotConnect("The API is currently unavailable.") - return stations_response["station"] + return stations_response.stations async def _fetch_stations_choices(self) -> list[SelectOptionDict]: """Fetch the stations options.""" @@ -50,7 +50,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): self.stations = await self._fetch_stations() return [ - SelectOptionDict(value=station["id"], label=station["standardname"]) + SelectOptionDict(value=station.id, label=station.standard_name) for station in self.stations ] @@ -72,12 +72,12 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): [station_from] = [ station for station in self.stations - if station["id"] == user_input[CONF_STATION_FROM] + if station.id == user_input[CONF_STATION_FROM] ] [station_to] = [ station for station in self.stations - if station["id"] == user_input[CONF_STATION_TO] + if station.id == user_input[CONF_STATION_TO] ] vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else "" await self.async_set_unique_id( @@ -85,7 +85,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): ) self._abort_if_unique_id_configured() - config_entry_name = f"Train from {station_from['standardname']} to {station_to['standardname']}" + config_entry_name = f"Train from {station_from.standard_name} to {station_to.standard_name}" return self.async_create_entry( title=config_entry_name, data=user_input, @@ -127,18 +127,18 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): station_live = None for station in self.stations: if user_input[CONF_STATION_FROM] in ( - station["standardname"], - station["name"], + station.standard_name, + station.name, ): station_from = station if user_input[CONF_STATION_TO] in ( - station["standardname"], - station["name"], + station.standard_name, + station.name, ): station_to = station if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in ( - station["standardname"], - station["name"], + station.standard_name, + station.name, ): station_live = station @@ -148,29 +148,29 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN): return self.async_abort(reason="same_station") # config flow uses id and not the standard name - user_input[CONF_STATION_FROM] = station_from["id"] - user_input[CONF_STATION_TO] = station_to["id"] + user_input[CONF_STATION_FROM] = station_from.id + user_input[CONF_STATION_TO] = station_to.id if station_live: - user_input[CONF_STATION_LIVE] = station_live["id"] + user_input[CONF_STATION_LIVE] = station_live.id entity_registry = er.async_get(self.hass) prefix = "live" vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else "" if entity_id := entity_registry.async_get_entity_id( Platform.SENSOR, DOMAIN, - f"{prefix}_{station_live['standardname']}_{station_from['standardname']}_{station_to['standardname']}", + f"{prefix}_{station_live.standard_name}_{station_from.standard_name}_{station_to.standard_name}", ): - new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}" + new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}" entity_registry.async_update_entity( entity_id, new_unique_id=new_unique_id ) if entity_id := entity_registry.async_get_entity_id( Platform.SENSOR, DOMAIN, - f"{prefix}_{station_live['name']}_{station_from['name']}_{station_to['name']}", + f"{prefix}_{station_live.name}_{station_from.name}_{station_to.name}", ): - new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}" + new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}" entity_registry.async_update_entity( entity_id, new_unique_id=new_unique_id ) diff --git a/homeassistant/components/nmbs/const.py b/homeassistant/components/nmbs/const.py index fddb7365501..04c8beb327d 100644 --- a/homeassistant/components/nmbs/const.py +++ b/homeassistant/components/nmbs/const.py @@ -19,11 +19,7 @@ CONF_SHOW_ON_MAP = "show_on_map" def find_station_by_name(hass: HomeAssistant, station_name: str): """Find given station_name in the station list.""" return next( - ( - s - for s in hass.data[DOMAIN] - if station_name in (s["standardname"], s["name"]) - ), + (s for s in hass.data[DOMAIN] if station_name in (s.standard_name, s.name)), None, ) @@ -31,6 +27,6 @@ def find_station_by_name(hass: HomeAssistant, station_name: str): def find_station(hass: HomeAssistant, station_name: str): """Find given station_id in the station list.""" return next( - (s for s in hass.data[DOMAIN] if station_name in s["id"]), + (s for s in hass.data[DOMAIN] if station_name in s.id), None, ) diff --git a/homeassistant/components/nmbs/manifest.json b/homeassistant/components/nmbs/manifest.json index 9016eff11f8..37ff9429a54 100644 --- a/homeassistant/components/nmbs/manifest.json +++ b/homeassistant/components/nmbs/manifest.json @@ -7,5 +7,5 @@ "iot_class": "cloud_polling", "loggers": ["pyrail"], "quality_scale": "legacy", - "requirements": ["pyrail==0.0.3"] + "requirements": ["pyrail==0.4.1"] } diff --git a/homeassistant/components/nmbs/sensor.py b/homeassistant/components/nmbs/sensor.py index c6dea2d0843..822b0236dd0 100644 --- a/homeassistant/components/nmbs/sensor.py +++ b/homeassistant/components/nmbs/sensor.py @@ -2,10 +2,12 @@ from __future__ import annotations +from datetime import datetime import logging from typing import Any from pyrail import iRail +from pyrail.models import ConnectionDetails, LiveboardDeparture, StationDetails import voluptuous as vol from homeassistant.components.sensor import ( @@ -23,6 +25,7 @@ from homeassistant.const import ( ) from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant from homeassistant.helpers import config_validation as cv +from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.entity_platform import ( AddConfigEntryEntitiesCallback, AddEntitiesCallback, @@ -44,8 +47,6 @@ from .const import ( # noqa: F401 _LOGGER = logging.getLogger(__name__) -API_FAILURE = -1 - DEFAULT_NAME = "NMBS" DEFAULT_ICON = "mdi:train" @@ -63,12 +64,12 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend( ) -def get_time_until(departure_time=None): +def get_time_until(departure_time: datetime | None = None): """Calculate the time between now and a train's departure time.""" if departure_time is None: return 0 - delta = dt_util.utc_from_timestamp(int(departure_time)) - dt_util.now() + delta = dt_util.as_utc(departure_time) - dt_util.utcnow() return round(delta.total_seconds() / 60) @@ -77,11 +78,9 @@ def get_delay_in_minutes(delay=0): return round(int(delay) / 60) -def get_ride_duration(departure_time, arrival_time, delay=0): +def get_ride_duration(departure_time: datetime, arrival_time: datetime, delay=0): """Calculate the total travel time in minutes.""" - duration = dt_util.utc_from_timestamp( - int(arrival_time) - ) - dt_util.utc_from_timestamp(int(departure_time)) + duration = arrival_time - departure_time duration_time = int(round(duration.total_seconds() / 60)) return duration_time + get_delay_in_minutes(delay) @@ -157,7 +156,7 @@ async def async_setup_entry( async_add_entities: AddConfigEntryEntitiesCallback, ) -> None: """Set up NMBS sensor entities based on a config entry.""" - api_client = iRail() + api_client = iRail(session=async_get_clientsession(hass)) name = config_entry.data.get(CONF_NAME, None) show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False) @@ -189,9 +188,9 @@ class NMBSLiveBoard(SensorEntity): def __init__( self, api_client: iRail, - live_station: dict[str, Any], - station_from: dict[str, Any], - station_to: dict[str, Any], + live_station: StationDetails, + station_from: StationDetails, + station_to: StationDetails, excl_vias: bool, ) -> None: """Initialize the sensor for getting liveboard data.""" @@ -201,7 +200,8 @@ class NMBSLiveBoard(SensorEntity): self._station_to = station_to self._excl_vias = excl_vias - self._attrs: dict[str, Any] | None = {} + self._attrs: LiveboardDeparture | None = None + self._state: str | None = None self.entity_registry_enabled_default = False @@ -209,22 +209,20 @@ class NMBSLiveBoard(SensorEntity): @property def name(self) -> str: """Return the sensor default name.""" - return f"Trains in {self._station['standardname']}" + return f"Trains in {self._station.standard_name}" @property def unique_id(self) -> str: """Return the unique ID.""" - unique_id = ( - f"{self._station['id']}_{self._station_from['id']}_{self._station_to['id']}" - ) + unique_id = f"{self._station.id}_{self._station_from.id}_{self._station_to.id}" vias = "_excl_vias" if self._excl_vias else "" return f"nmbs_live_{unique_id}{vias}" @property def icon(self) -> str: """Return the default icon or an alert icon if delays.""" - if self._attrs and int(self._attrs["delay"]) > 0: + if self._attrs and int(self._attrs.delay) > 0: return DEFAULT_ICON_ALERT return DEFAULT_ICON @@ -240,15 +238,15 @@ class NMBSLiveBoard(SensorEntity): if self._state is None or not self._attrs: return None - delay = get_delay_in_minutes(self._attrs["delay"]) - departure = get_time_until(self._attrs["time"]) + delay = get_delay_in_minutes(self._attrs.delay) + departure = get_time_until(self._attrs.time) attrs = { "departure": f"In {departure} minutes", "departure_minutes": departure, - "extra_train": int(self._attrs["isExtra"]) > 0, - "vehicle_id": self._attrs["vehicle"], - "monitored_station": self._station["standardname"], + "extra_train": self._attrs.is_extra, + "vehicle_id": self._attrs.vehicle, + "monitored_station": self._station.standard_name, } if delay > 0: @@ -257,28 +255,26 @@ class NMBSLiveBoard(SensorEntity): return attrs - def update(self) -> None: + async def async_update(self, **kwargs: Any) -> None: """Set the state equal to the next departure.""" - liveboard = self._api_client.get_liveboard(self._station["id"]) + liveboard = await self._api_client.get_liveboard(self._station.id) - if liveboard == API_FAILURE: + if liveboard is None: _LOGGER.warning("API failed in NMBSLiveBoard") return - if not (departures := liveboard.get("departures")): + if not (departures := liveboard.departures): _LOGGER.warning("API returned invalid departures: %r", liveboard) return _LOGGER.debug("API returned departures: %r", departures) - if departures["number"] == "0": + if len(departures) == 0: # No trains are scheduled return - next_departure = departures["departure"][0] + next_departure = departures[0] self._attrs = next_departure - self._state = ( - f"Track {next_departure['platform']} - {next_departure['station']}" - ) + self._state = f"Track {next_departure.platform} - {next_departure.station}" class NMBSSensor(SensorEntity): @@ -292,8 +288,8 @@ class NMBSSensor(SensorEntity): api_client: iRail, name: str, show_on_map: bool, - station_from: dict[str, Any], - station_to: dict[str, Any], + station_from: StationDetails, + station_to: StationDetails, excl_vias: bool, ) -> None: """Initialize the NMBS connection sensor.""" @@ -304,13 +300,13 @@ class NMBSSensor(SensorEntity): self._station_to = station_to self._excl_vias = excl_vias - self._attrs: dict[str, Any] | None = {} + self._attrs: ConnectionDetails | None = None self._state = None @property def unique_id(self) -> str: """Return the unique ID.""" - unique_id = f"{self._station_from['id']}_{self._station_to['id']}" + unique_id = f"{self._station_from.id}_{self._station_to.id}" vias = "_excl_vias" if self._excl_vias else "" return f"nmbs_connection_{unique_id}{vias}" @@ -319,14 +315,14 @@ class NMBSSensor(SensorEntity): def name(self) -> str: """Return the name of the sensor.""" if self._name is None: - return f"Train from {self._station_from['standardname']} to {self._station_to['standardname']}" + return f"Train from {self._station_from.standard_name} to {self._station_to.standard_name}" return self._name @property def icon(self) -> str: """Return the sensor default icon or an alert icon if any delay.""" if self._attrs: - delay = get_delay_in_minutes(self._attrs["departure"]["delay"]) + delay = get_delay_in_minutes(self._attrs.departure.delay) if delay > 0: return "mdi:alert-octagon" @@ -338,19 +334,19 @@ class NMBSSensor(SensorEntity): if self._state is None or not self._attrs: return None - delay = get_delay_in_minutes(self._attrs["departure"]["delay"]) - departure = get_time_until(self._attrs["departure"]["time"]) - canceled = int(self._attrs["departure"]["canceled"]) + delay = get_delay_in_minutes(self._attrs.departure.delay) + departure = get_time_until(self._attrs.departure.time) + canceled = self._attrs.departure.canceled attrs = { - "destination": self._attrs["departure"]["station"], - "direction": self._attrs["departure"]["direction"]["name"], - "platform_arriving": self._attrs["arrival"]["platform"], - "platform_departing": self._attrs["departure"]["platform"], - "vehicle_id": self._attrs["departure"]["vehicle"], + "destination": self._attrs.departure.station, + "direction": self._attrs.departure.direction.name, + "platform_arriving": self._attrs.arrival.platform, + "platform_departing": self._attrs.departure.platform, + "vehicle_id": self._attrs.departure.vehicle, } - if canceled != 1: + if not canceled: attrs["departure"] = f"In {departure} minutes" attrs["departure_minutes"] = departure attrs["canceled"] = False @@ -364,14 +360,14 @@ class NMBSSensor(SensorEntity): attrs[ATTR_LONGITUDE] = self.station_coordinates[1] if self.is_via_connection and not self._excl_vias: - via = self._attrs["vias"]["via"][0] + via = self._attrs.vias.via[0] - attrs["via"] = via["station"] - attrs["via_arrival_platform"] = via["arrival"]["platform"] - attrs["via_transfer_platform"] = via["departure"]["platform"] + attrs["via"] = via.station + attrs["via_arrival_platform"] = via.arrival.platform + attrs["via_transfer_platform"] = via.departure.platform attrs["via_transfer_time"] = get_delay_in_minutes( - via["timebetween"] - ) + get_delay_in_minutes(via["departure"]["delay"]) + via.timebetween + ) + get_delay_in_minutes(via.departure.delay) if delay > 0: attrs["delay"] = f"{delay} minutes" @@ -390,8 +386,8 @@ class NMBSSensor(SensorEntity): if self._state is None or not self._attrs: return [] - latitude = float(self._attrs["departure"]["stationinfo"]["locationY"]) - longitude = float(self._attrs["departure"]["stationinfo"]["locationX"]) + latitude = float(self._attrs.departure.station_info.latitude) + longitude = float(self._attrs.departure.station_info.longitude) return [latitude, longitude] @property @@ -400,24 +396,24 @@ class NMBSSensor(SensorEntity): if not self._attrs: return False - return "vias" in self._attrs and int(self._attrs["vias"]["number"]) > 0 + return self._attrs.vias is not None and len(self._attrs.vias) > 0 - def update(self) -> None: + async def async_update(self, **kwargs: Any) -> None: """Set the state to the duration of a connection.""" - connections = self._api_client.get_connections( - self._station_from["id"], self._station_to["id"] + connections = await self._api_client.get_connections( + self._station_from.id, self._station_to.id ) - if connections == API_FAILURE: + if connections is None: _LOGGER.warning("API failed in NMBSSensor") return - if not (connection := connections.get("connection")): + if not (connection := connections.connections): _LOGGER.warning("API returned invalid connection: %r", connections) return _LOGGER.debug("API returned connection: %r", connection) - if int(connection[0]["departure"]["left"]) > 0: + if connection[0].departure.left: next_connection = connection[1] else: next_connection = connection[0] @@ -431,9 +427,9 @@ class NMBSSensor(SensorEntity): return duration = get_ride_duration( - next_connection["departure"]["time"], - next_connection["arrival"]["time"], - next_connection["departure"]["delay"], + next_connection.departure.time, + next_connection.arrival.time, + next_connection.departure.delay, ) self._state = duration diff --git a/requirements_all.txt b/requirements_all.txt index 696aef8b03b..5d274a3ba6a 100644 --- a/requirements_all.txt +++ b/requirements_all.txt @@ -2244,7 +2244,7 @@ pyqvrpro==0.52 pyqwikswitch==0.93 # homeassistant.components.nmbs -pyrail==0.0.3 +pyrail==0.4.1 # homeassistant.components.rainbird pyrainbird==6.0.1 diff --git a/requirements_test_all.txt b/requirements_test_all.txt index b9509b7fac3..19e143e3975 100644 --- a/requirements_test_all.txt +++ b/requirements_test_all.txt @@ -1834,7 +1834,7 @@ pyps4-2ndscreen==1.3.1 pyqwikswitch==0.93 # homeassistant.components.nmbs -pyrail==0.0.3 +pyrail==0.4.1 # homeassistant.components.rainbird pyrainbird==6.0.1 diff --git a/tests/components/nmbs/__init__.py b/tests/components/nmbs/__init__.py index 91226950aba..3d284e5bb77 100644 --- a/tests/components/nmbs/__init__.py +++ b/tests/components/nmbs/__init__.py @@ -1,20 +1 @@ """Tests for the NMBS integration.""" - -import json -from typing import Any - -from tests.common import load_fixture - - -def mock_api_unavailable() -> dict[str, Any]: - """Mock for unavailable api.""" - return -1 - - -def mock_station_response() -> dict[str, Any]: - """Mock for valid station response.""" - dummy_stations_response: dict[str, Any] = json.loads( - load_fixture("stations.json", "nmbs") - ) - - return dummy_stations_response diff --git a/tests/components/nmbs/conftest.py b/tests/components/nmbs/conftest.py index 69200fc4c98..a39334ba62c 100644 --- a/tests/components/nmbs/conftest.py +++ b/tests/components/nmbs/conftest.py @@ -3,6 +3,7 @@ from collections.abc import Generator from unittest.mock import AsyncMock, patch +from pyrail.models import StationsApiResponse import pytest from homeassistant.components.nmbs.const import ( @@ -38,8 +39,8 @@ def mock_nmbs_client() -> Generator[AsyncMock]: ), ): client = mock_client.return_value - client.get_stations.return_value = load_json_object_fixture( - "stations.json", DOMAIN + client.get_stations.return_value = StationsApiResponse.from_dict( + load_json_object_fixture("stations.json", DOMAIN) ) yield client diff --git a/tests/components/nmbs/test_config_flow.py b/tests/components/nmbs/test_config_flow.py index ff4c5bdf72a..7e0f087607b 100644 --- a/tests/components/nmbs/test_config_flow.py +++ b/tests/components/nmbs/test_config_flow.py @@ -142,7 +142,7 @@ async def test_unavailable_api( hass: HomeAssistant, mock_nmbs_client: AsyncMock ) -> None: """Test starting a flow by user and api is unavailable.""" - mock_nmbs_client.get_stations.return_value = -1 + mock_nmbs_client.get_stations.return_value = None result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER}, @@ -203,7 +203,7 @@ async def test_unavailable_api_import( hass: HomeAssistant, mock_nmbs_client: AsyncMock ) -> None: """Test starting a flow by import and api is unavailable.""" - mock_nmbs_client.get_stations.return_value = -1 + mock_nmbs_client.get_stations.return_value = None result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": SOURCE_IMPORT},