mirror of
https://github.com/home-assistant/core.git
synced 2025-07-17 10:17:09 +00:00
Add typing/async to NMBS (#139002)
* Add typing/async to NMBS * Fix tests * Boolean fields * Update homeassistant/components/nmbs/sensor.py Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com> --------- Co-authored-by: Shay Levy <levyshay1@gmail.com> Co-authored-by: Jorim Tielemans <tielemans.jorim@gmail.com>
This commit is contained in:
parent
de4540c68e
commit
40099547ef
@ -8,6 +8,7 @@ from homeassistant.config_entries import ConfigEntry
|
|||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.core import HomeAssistant
|
from homeassistant.core import HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.typing import ConfigType
|
from homeassistant.helpers.typing import ConfigType
|
||||||
|
|
||||||
from .const import DOMAIN
|
from .const import DOMAIN
|
||||||
@ -22,13 +23,13 @@ CONFIG_SCHEMA = cv.empty_config_schema(DOMAIN)
|
|||||||
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
async def async_setup(hass: HomeAssistant, config: ConfigType) -> bool:
|
||||||
"""Set up the NMBS component."""
|
"""Set up the NMBS component."""
|
||||||
|
|
||||||
api_client = iRail()
|
api_client = iRail(session=async_get_clientsession(hass))
|
||||||
|
|
||||||
hass.data.setdefault(DOMAIN, {})
|
hass.data.setdefault(DOMAIN, {})
|
||||||
station_response = await hass.async_add_executor_job(api_client.get_stations)
|
station_response = await api_client.get_stations()
|
||||||
if station_response == -1:
|
if station_response is None:
|
||||||
return False
|
return False
|
||||||
hass.data[DOMAIN] = station_response["station"]
|
hass.data[DOMAIN] = station_response.stations
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -3,11 +3,13 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pyrail import iRail
|
from pyrail import iRail
|
||||||
|
from pyrail.models import StationDetails
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
|
||||||
from homeassistant.const import Platform
|
from homeassistant.const import Platform
|
||||||
from homeassistant.helpers import entity_registry as er
|
from homeassistant.helpers import entity_registry as er
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.selector import (
|
from homeassistant.helpers.selector import (
|
||||||
BooleanSelector,
|
BooleanSelector,
|
||||||
SelectOptionDict,
|
SelectOptionDict,
|
||||||
@ -31,17 +33,15 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
"""Initialize."""
|
"""Initialize."""
|
||||||
self.api_client = iRail()
|
self.stations: list[StationDetails] = []
|
||||||
self.stations: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def _fetch_stations(self) -> list[dict[str, Any]]:
|
async def _fetch_stations(self) -> list[StationDetails]:
|
||||||
"""Fetch the stations."""
|
"""Fetch the stations."""
|
||||||
stations_response = await self.hass.async_add_executor_job(
|
api_client = iRail(session=async_get_clientsession(self.hass))
|
||||||
self.api_client.get_stations
|
stations_response = await api_client.get_stations()
|
||||||
)
|
if stations_response is None:
|
||||||
if stations_response == -1:
|
|
||||||
raise CannotConnect("The API is currently unavailable.")
|
raise CannotConnect("The API is currently unavailable.")
|
||||||
return stations_response["station"]
|
return stations_response.stations
|
||||||
|
|
||||||
async def _fetch_stations_choices(self) -> list[SelectOptionDict]:
|
async def _fetch_stations_choices(self) -> list[SelectOptionDict]:
|
||||||
"""Fetch the stations options."""
|
"""Fetch the stations options."""
|
||||||
@ -50,7 +50,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
self.stations = await self._fetch_stations()
|
self.stations = await self._fetch_stations()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
SelectOptionDict(value=station["id"], label=station["standardname"])
|
SelectOptionDict(value=station.id, label=station.standard_name)
|
||||||
for station in self.stations
|
for station in self.stations
|
||||||
]
|
]
|
||||||
|
|
||||||
@ -72,12 +72,12 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
[station_from] = [
|
[station_from] = [
|
||||||
station
|
station
|
||||||
for station in self.stations
|
for station in self.stations
|
||||||
if station["id"] == user_input[CONF_STATION_FROM]
|
if station.id == user_input[CONF_STATION_FROM]
|
||||||
]
|
]
|
||||||
[station_to] = [
|
[station_to] = [
|
||||||
station
|
station
|
||||||
for station in self.stations
|
for station in self.stations
|
||||||
if station["id"] == user_input[CONF_STATION_TO]
|
if station.id == user_input[CONF_STATION_TO]
|
||||||
]
|
]
|
||||||
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else ""
|
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS) else ""
|
||||||
await self.async_set_unique_id(
|
await self.async_set_unique_id(
|
||||||
@ -85,7 +85,7 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
)
|
)
|
||||||
self._abort_if_unique_id_configured()
|
self._abort_if_unique_id_configured()
|
||||||
|
|
||||||
config_entry_name = f"Train from {station_from['standardname']} to {station_to['standardname']}"
|
config_entry_name = f"Train from {station_from.standard_name} to {station_to.standard_name}"
|
||||||
return self.async_create_entry(
|
return self.async_create_entry(
|
||||||
title=config_entry_name,
|
title=config_entry_name,
|
||||||
data=user_input,
|
data=user_input,
|
||||||
@ -127,18 +127,18 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
station_live = None
|
station_live = None
|
||||||
for station in self.stations:
|
for station in self.stations:
|
||||||
if user_input[CONF_STATION_FROM] in (
|
if user_input[CONF_STATION_FROM] in (
|
||||||
station["standardname"],
|
station.standard_name,
|
||||||
station["name"],
|
station.name,
|
||||||
):
|
):
|
||||||
station_from = station
|
station_from = station
|
||||||
if user_input[CONF_STATION_TO] in (
|
if user_input[CONF_STATION_TO] in (
|
||||||
station["standardname"],
|
station.standard_name,
|
||||||
station["name"],
|
station.name,
|
||||||
):
|
):
|
||||||
station_to = station
|
station_to = station
|
||||||
if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in (
|
if CONF_STATION_LIVE in user_input and user_input[CONF_STATION_LIVE] in (
|
||||||
station["standardname"],
|
station.standard_name,
|
||||||
station["name"],
|
station.name,
|
||||||
):
|
):
|
||||||
station_live = station
|
station_live = station
|
||||||
|
|
||||||
@ -148,29 +148,29 @@ class NMBSConfigFlow(ConfigFlow, domain=DOMAIN):
|
|||||||
return self.async_abort(reason="same_station")
|
return self.async_abort(reason="same_station")
|
||||||
|
|
||||||
# config flow uses id and not the standard name
|
# config flow uses id and not the standard name
|
||||||
user_input[CONF_STATION_FROM] = station_from["id"]
|
user_input[CONF_STATION_FROM] = station_from.id
|
||||||
user_input[CONF_STATION_TO] = station_to["id"]
|
user_input[CONF_STATION_TO] = station_to.id
|
||||||
|
|
||||||
if station_live:
|
if station_live:
|
||||||
user_input[CONF_STATION_LIVE] = station_live["id"]
|
user_input[CONF_STATION_LIVE] = station_live.id
|
||||||
entity_registry = er.async_get(self.hass)
|
entity_registry = er.async_get(self.hass)
|
||||||
prefix = "live"
|
prefix = "live"
|
||||||
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else ""
|
vias = "_excl_vias" if user_input.get(CONF_EXCLUDE_VIAS, False) else ""
|
||||||
if entity_id := entity_registry.async_get_entity_id(
|
if entity_id := entity_registry.async_get_entity_id(
|
||||||
Platform.SENSOR,
|
Platform.SENSOR,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
f"{prefix}_{station_live['standardname']}_{station_from['standardname']}_{station_to['standardname']}",
|
f"{prefix}_{station_live.standard_name}_{station_from.standard_name}_{station_to.standard_name}",
|
||||||
):
|
):
|
||||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
|
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
|
||||||
entity_registry.async_update_entity(
|
entity_registry.async_update_entity(
|
||||||
entity_id, new_unique_id=new_unique_id
|
entity_id, new_unique_id=new_unique_id
|
||||||
)
|
)
|
||||||
if entity_id := entity_registry.async_get_entity_id(
|
if entity_id := entity_registry.async_get_entity_id(
|
||||||
Platform.SENSOR,
|
Platform.SENSOR,
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
f"{prefix}_{station_live['name']}_{station_from['name']}_{station_to['name']}",
|
f"{prefix}_{station_live.name}_{station_from.name}_{station_to.name}",
|
||||||
):
|
):
|
||||||
new_unique_id = f"{DOMAIN}_{prefix}_{station_live['id']}_{station_from['id']}_{station_to['id']}{vias}"
|
new_unique_id = f"{DOMAIN}_{prefix}_{station_live.id}_{station_from.id}_{station_to.id}{vias}"
|
||||||
entity_registry.async_update_entity(
|
entity_registry.async_update_entity(
|
||||||
entity_id, new_unique_id=new_unique_id
|
entity_id, new_unique_id=new_unique_id
|
||||||
)
|
)
|
||||||
|
@ -19,11 +19,7 @@ CONF_SHOW_ON_MAP = "show_on_map"
|
|||||||
def find_station_by_name(hass: HomeAssistant, station_name: str):
|
def find_station_by_name(hass: HomeAssistant, station_name: str):
|
||||||
"""Find given station_name in the station list."""
|
"""Find given station_name in the station list."""
|
||||||
return next(
|
return next(
|
||||||
(
|
(s for s in hass.data[DOMAIN] if station_name in (s.standard_name, s.name)),
|
||||||
s
|
|
||||||
for s in hass.data[DOMAIN]
|
|
||||||
if station_name in (s["standardname"], s["name"])
|
|
||||||
),
|
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -31,6 +27,6 @@ def find_station_by_name(hass: HomeAssistant, station_name: str):
|
|||||||
def find_station(hass: HomeAssistant, station_name: str):
|
def find_station(hass: HomeAssistant, station_name: str):
|
||||||
"""Find given station_id in the station list."""
|
"""Find given station_id in the station list."""
|
||||||
return next(
|
return next(
|
||||||
(s for s in hass.data[DOMAIN] if station_name in s["id"]),
|
(s for s in hass.data[DOMAIN] if station_name in s.id),
|
||||||
None,
|
None,
|
||||||
)
|
)
|
||||||
|
@ -7,5 +7,5 @@
|
|||||||
"iot_class": "cloud_polling",
|
"iot_class": "cloud_polling",
|
||||||
"loggers": ["pyrail"],
|
"loggers": ["pyrail"],
|
||||||
"quality_scale": "legacy",
|
"quality_scale": "legacy",
|
||||||
"requirements": ["pyrail==0.0.3"]
|
"requirements": ["pyrail==0.4.1"]
|
||||||
}
|
}
|
||||||
|
@ -2,10 +2,12 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pyrail import iRail
|
from pyrail import iRail
|
||||||
|
from pyrail.models import ConnectionDetails, LiveboardDeparture, StationDetails
|
||||||
import voluptuous as vol
|
import voluptuous as vol
|
||||||
|
|
||||||
from homeassistant.components.sensor import (
|
from homeassistant.components.sensor import (
|
||||||
@ -23,6 +25,7 @@ from homeassistant.const import (
|
|||||||
)
|
)
|
||||||
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
|
from homeassistant.core import DOMAIN as HOMEASSISTANT_DOMAIN, HomeAssistant
|
||||||
from homeassistant.helpers import config_validation as cv
|
from homeassistant.helpers import config_validation as cv
|
||||||
|
from homeassistant.helpers.aiohttp_client import async_get_clientsession
|
||||||
from homeassistant.helpers.entity_platform import (
|
from homeassistant.helpers.entity_platform import (
|
||||||
AddConfigEntryEntitiesCallback,
|
AddConfigEntryEntitiesCallback,
|
||||||
AddEntitiesCallback,
|
AddEntitiesCallback,
|
||||||
@ -44,8 +47,6 @@ from .const import ( # noqa: F401
|
|||||||
|
|
||||||
_LOGGER = logging.getLogger(__name__)
|
_LOGGER = logging.getLogger(__name__)
|
||||||
|
|
||||||
API_FAILURE = -1
|
|
||||||
|
|
||||||
DEFAULT_NAME = "NMBS"
|
DEFAULT_NAME = "NMBS"
|
||||||
|
|
||||||
DEFAULT_ICON = "mdi:train"
|
DEFAULT_ICON = "mdi:train"
|
||||||
@ -63,12 +64,12 @@ PLATFORM_SCHEMA = SENSOR_PLATFORM_SCHEMA.extend(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_time_until(departure_time=None):
|
def get_time_until(departure_time: datetime | None = None):
|
||||||
"""Calculate the time between now and a train's departure time."""
|
"""Calculate the time between now and a train's departure time."""
|
||||||
if departure_time is None:
|
if departure_time is None:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
delta = dt_util.utc_from_timestamp(int(departure_time)) - dt_util.now()
|
delta = dt_util.as_utc(departure_time) - dt_util.utcnow()
|
||||||
return round(delta.total_seconds() / 60)
|
return round(delta.total_seconds() / 60)
|
||||||
|
|
||||||
|
|
||||||
@ -77,11 +78,9 @@ def get_delay_in_minutes(delay=0):
|
|||||||
return round(int(delay) / 60)
|
return round(int(delay) / 60)
|
||||||
|
|
||||||
|
|
||||||
def get_ride_duration(departure_time, arrival_time, delay=0):
|
def get_ride_duration(departure_time: datetime, arrival_time: datetime, delay=0):
|
||||||
"""Calculate the total travel time in minutes."""
|
"""Calculate the total travel time in minutes."""
|
||||||
duration = dt_util.utc_from_timestamp(
|
duration = arrival_time - departure_time
|
||||||
int(arrival_time)
|
|
||||||
) - dt_util.utc_from_timestamp(int(departure_time))
|
|
||||||
duration_time = int(round(duration.total_seconds() / 60))
|
duration_time = int(round(duration.total_seconds() / 60))
|
||||||
return duration_time + get_delay_in_minutes(delay)
|
return duration_time + get_delay_in_minutes(delay)
|
||||||
|
|
||||||
@ -157,7 +156,7 @@ async def async_setup_entry(
|
|||||||
async_add_entities: AddConfigEntryEntitiesCallback,
|
async_add_entities: AddConfigEntryEntitiesCallback,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set up NMBS sensor entities based on a config entry."""
|
"""Set up NMBS sensor entities based on a config entry."""
|
||||||
api_client = iRail()
|
api_client = iRail(session=async_get_clientsession(hass))
|
||||||
|
|
||||||
name = config_entry.data.get(CONF_NAME, None)
|
name = config_entry.data.get(CONF_NAME, None)
|
||||||
show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False)
|
show_on_map = config_entry.data.get(CONF_SHOW_ON_MAP, False)
|
||||||
@ -189,9 +188,9 @@ class NMBSLiveBoard(SensorEntity):
|
|||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
api_client: iRail,
|
api_client: iRail,
|
||||||
live_station: dict[str, Any],
|
live_station: StationDetails,
|
||||||
station_from: dict[str, Any],
|
station_from: StationDetails,
|
||||||
station_to: dict[str, Any],
|
station_to: StationDetails,
|
||||||
excl_vias: bool,
|
excl_vias: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the sensor for getting liveboard data."""
|
"""Initialize the sensor for getting liveboard data."""
|
||||||
@ -201,7 +200,8 @@ class NMBSLiveBoard(SensorEntity):
|
|||||||
self._station_to = station_to
|
self._station_to = station_to
|
||||||
|
|
||||||
self._excl_vias = excl_vias
|
self._excl_vias = excl_vias
|
||||||
self._attrs: dict[str, Any] | None = {}
|
self._attrs: LiveboardDeparture | None = None
|
||||||
|
|
||||||
self._state: str | None = None
|
self._state: str | None = None
|
||||||
|
|
||||||
self.entity_registry_enabled_default = False
|
self.entity_registry_enabled_default = False
|
||||||
@ -209,22 +209,20 @@ class NMBSLiveBoard(SensorEntity):
|
|||||||
@property
|
@property
|
||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Return the sensor default name."""
|
"""Return the sensor default name."""
|
||||||
return f"Trains in {self._station['standardname']}"
|
return f"Trains in {self._station.standard_name}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unique_id(self) -> str:
|
def unique_id(self) -> str:
|
||||||
"""Return the unique ID."""
|
"""Return the unique ID."""
|
||||||
|
|
||||||
unique_id = (
|
unique_id = f"{self._station.id}_{self._station_from.id}_{self._station_to.id}"
|
||||||
f"{self._station['id']}_{self._station_from['id']}_{self._station_to['id']}"
|
|
||||||
)
|
|
||||||
vias = "_excl_vias" if self._excl_vias else ""
|
vias = "_excl_vias" if self._excl_vias else ""
|
||||||
return f"nmbs_live_{unique_id}{vias}"
|
return f"nmbs_live_{unique_id}{vias}"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def icon(self) -> str:
|
def icon(self) -> str:
|
||||||
"""Return the default icon or an alert icon if delays."""
|
"""Return the default icon or an alert icon if delays."""
|
||||||
if self._attrs and int(self._attrs["delay"]) > 0:
|
if self._attrs and int(self._attrs.delay) > 0:
|
||||||
return DEFAULT_ICON_ALERT
|
return DEFAULT_ICON_ALERT
|
||||||
|
|
||||||
return DEFAULT_ICON
|
return DEFAULT_ICON
|
||||||
@ -240,15 +238,15 @@ class NMBSLiveBoard(SensorEntity):
|
|||||||
if self._state is None or not self._attrs:
|
if self._state is None or not self._attrs:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
delay = get_delay_in_minutes(self._attrs["delay"])
|
delay = get_delay_in_minutes(self._attrs.delay)
|
||||||
departure = get_time_until(self._attrs["time"])
|
departure = get_time_until(self._attrs.time)
|
||||||
|
|
||||||
attrs = {
|
attrs = {
|
||||||
"departure": f"In {departure} minutes",
|
"departure": f"In {departure} minutes",
|
||||||
"departure_minutes": departure,
|
"departure_minutes": departure,
|
||||||
"extra_train": int(self._attrs["isExtra"]) > 0,
|
"extra_train": self._attrs.is_extra,
|
||||||
"vehicle_id": self._attrs["vehicle"],
|
"vehicle_id": self._attrs.vehicle,
|
||||||
"monitored_station": self._station["standardname"],
|
"monitored_station": self._station.standard_name,
|
||||||
}
|
}
|
||||||
|
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
@ -257,28 +255,26 @@ class NMBSLiveBoard(SensorEntity):
|
|||||||
|
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def update(self) -> None:
|
async def async_update(self, **kwargs: Any) -> None:
|
||||||
"""Set the state equal to the next departure."""
|
"""Set the state equal to the next departure."""
|
||||||
liveboard = self._api_client.get_liveboard(self._station["id"])
|
liveboard = await self._api_client.get_liveboard(self._station.id)
|
||||||
|
|
||||||
if liveboard == API_FAILURE:
|
if liveboard is None:
|
||||||
_LOGGER.warning("API failed in NMBSLiveBoard")
|
_LOGGER.warning("API failed in NMBSLiveBoard")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not (departures := liveboard.get("departures")):
|
if not (departures := liveboard.departures):
|
||||||
_LOGGER.warning("API returned invalid departures: %r", liveboard)
|
_LOGGER.warning("API returned invalid departures: %r", liveboard)
|
||||||
return
|
return
|
||||||
|
|
||||||
_LOGGER.debug("API returned departures: %r", departures)
|
_LOGGER.debug("API returned departures: %r", departures)
|
||||||
if departures["number"] == "0":
|
if len(departures) == 0:
|
||||||
# No trains are scheduled
|
# No trains are scheduled
|
||||||
return
|
return
|
||||||
next_departure = departures["departure"][0]
|
next_departure = departures[0]
|
||||||
|
|
||||||
self._attrs = next_departure
|
self._attrs = next_departure
|
||||||
self._state = (
|
self._state = f"Track {next_departure.platform} - {next_departure.station}"
|
||||||
f"Track {next_departure['platform']} - {next_departure['station']}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class NMBSSensor(SensorEntity):
|
class NMBSSensor(SensorEntity):
|
||||||
@ -292,8 +288,8 @@ class NMBSSensor(SensorEntity):
|
|||||||
api_client: iRail,
|
api_client: iRail,
|
||||||
name: str,
|
name: str,
|
||||||
show_on_map: bool,
|
show_on_map: bool,
|
||||||
station_from: dict[str, Any],
|
station_from: StationDetails,
|
||||||
station_to: dict[str, Any],
|
station_to: StationDetails,
|
||||||
excl_vias: bool,
|
excl_vias: bool,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize the NMBS connection sensor."""
|
"""Initialize the NMBS connection sensor."""
|
||||||
@ -304,13 +300,13 @@ class NMBSSensor(SensorEntity):
|
|||||||
self._station_to = station_to
|
self._station_to = station_to
|
||||||
self._excl_vias = excl_vias
|
self._excl_vias = excl_vias
|
||||||
|
|
||||||
self._attrs: dict[str, Any] | None = {}
|
self._attrs: ConnectionDetails | None = None
|
||||||
self._state = None
|
self._state = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def unique_id(self) -> str:
|
def unique_id(self) -> str:
|
||||||
"""Return the unique ID."""
|
"""Return the unique ID."""
|
||||||
unique_id = f"{self._station_from['id']}_{self._station_to['id']}"
|
unique_id = f"{self._station_from.id}_{self._station_to.id}"
|
||||||
|
|
||||||
vias = "_excl_vias" if self._excl_vias else ""
|
vias = "_excl_vias" if self._excl_vias else ""
|
||||||
return f"nmbs_connection_{unique_id}{vias}"
|
return f"nmbs_connection_{unique_id}{vias}"
|
||||||
@ -319,14 +315,14 @@ class NMBSSensor(SensorEntity):
|
|||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
"""Return the name of the sensor."""
|
"""Return the name of the sensor."""
|
||||||
if self._name is None:
|
if self._name is None:
|
||||||
return f"Train from {self._station_from['standardname']} to {self._station_to['standardname']}"
|
return f"Train from {self._station_from.standard_name} to {self._station_to.standard_name}"
|
||||||
return self._name
|
return self._name
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def icon(self) -> str:
|
def icon(self) -> str:
|
||||||
"""Return the sensor default icon or an alert icon if any delay."""
|
"""Return the sensor default icon or an alert icon if any delay."""
|
||||||
if self._attrs:
|
if self._attrs:
|
||||||
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
|
delay = get_delay_in_minutes(self._attrs.departure.delay)
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
return "mdi:alert-octagon"
|
return "mdi:alert-octagon"
|
||||||
|
|
||||||
@ -338,19 +334,19 @@ class NMBSSensor(SensorEntity):
|
|||||||
if self._state is None or not self._attrs:
|
if self._state is None or not self._attrs:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
delay = get_delay_in_minutes(self._attrs["departure"]["delay"])
|
delay = get_delay_in_minutes(self._attrs.departure.delay)
|
||||||
departure = get_time_until(self._attrs["departure"]["time"])
|
departure = get_time_until(self._attrs.departure.time)
|
||||||
canceled = int(self._attrs["departure"]["canceled"])
|
canceled = self._attrs.departure.canceled
|
||||||
|
|
||||||
attrs = {
|
attrs = {
|
||||||
"destination": self._attrs["departure"]["station"],
|
"destination": self._attrs.departure.station,
|
||||||
"direction": self._attrs["departure"]["direction"]["name"],
|
"direction": self._attrs.departure.direction.name,
|
||||||
"platform_arriving": self._attrs["arrival"]["platform"],
|
"platform_arriving": self._attrs.arrival.platform,
|
||||||
"platform_departing": self._attrs["departure"]["platform"],
|
"platform_departing": self._attrs.departure.platform,
|
||||||
"vehicle_id": self._attrs["departure"]["vehicle"],
|
"vehicle_id": self._attrs.departure.vehicle,
|
||||||
}
|
}
|
||||||
|
|
||||||
if canceled != 1:
|
if not canceled:
|
||||||
attrs["departure"] = f"In {departure} minutes"
|
attrs["departure"] = f"In {departure} minutes"
|
||||||
attrs["departure_minutes"] = departure
|
attrs["departure_minutes"] = departure
|
||||||
attrs["canceled"] = False
|
attrs["canceled"] = False
|
||||||
@ -364,14 +360,14 @@ class NMBSSensor(SensorEntity):
|
|||||||
attrs[ATTR_LONGITUDE] = self.station_coordinates[1]
|
attrs[ATTR_LONGITUDE] = self.station_coordinates[1]
|
||||||
|
|
||||||
if self.is_via_connection and not self._excl_vias:
|
if self.is_via_connection and not self._excl_vias:
|
||||||
via = self._attrs["vias"]["via"][0]
|
via = self._attrs.vias.via[0]
|
||||||
|
|
||||||
attrs["via"] = via["station"]
|
attrs["via"] = via.station
|
||||||
attrs["via_arrival_platform"] = via["arrival"]["platform"]
|
attrs["via_arrival_platform"] = via.arrival.platform
|
||||||
attrs["via_transfer_platform"] = via["departure"]["platform"]
|
attrs["via_transfer_platform"] = via.departure.platform
|
||||||
attrs["via_transfer_time"] = get_delay_in_minutes(
|
attrs["via_transfer_time"] = get_delay_in_minutes(
|
||||||
via["timebetween"]
|
via.timebetween
|
||||||
) + get_delay_in_minutes(via["departure"]["delay"])
|
) + get_delay_in_minutes(via.departure.delay)
|
||||||
|
|
||||||
if delay > 0:
|
if delay > 0:
|
||||||
attrs["delay"] = f"{delay} minutes"
|
attrs["delay"] = f"{delay} minutes"
|
||||||
@ -390,8 +386,8 @@ class NMBSSensor(SensorEntity):
|
|||||||
if self._state is None or not self._attrs:
|
if self._state is None or not self._attrs:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
latitude = float(self._attrs["departure"]["stationinfo"]["locationY"])
|
latitude = float(self._attrs.departure.station_info.latitude)
|
||||||
longitude = float(self._attrs["departure"]["stationinfo"]["locationX"])
|
longitude = float(self._attrs.departure.station_info.longitude)
|
||||||
return [latitude, longitude]
|
return [latitude, longitude]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@ -400,24 +396,24 @@ class NMBSSensor(SensorEntity):
|
|||||||
if not self._attrs:
|
if not self._attrs:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return "vias" in self._attrs and int(self._attrs["vias"]["number"]) > 0
|
return self._attrs.vias is not None and len(self._attrs.vias) > 0
|
||||||
|
|
||||||
def update(self) -> None:
|
async def async_update(self, **kwargs: Any) -> None:
|
||||||
"""Set the state to the duration of a connection."""
|
"""Set the state to the duration of a connection."""
|
||||||
connections = self._api_client.get_connections(
|
connections = await self._api_client.get_connections(
|
||||||
self._station_from["id"], self._station_to["id"]
|
self._station_from.id, self._station_to.id
|
||||||
)
|
)
|
||||||
|
|
||||||
if connections == API_FAILURE:
|
if connections is None:
|
||||||
_LOGGER.warning("API failed in NMBSSensor")
|
_LOGGER.warning("API failed in NMBSSensor")
|
||||||
return
|
return
|
||||||
|
|
||||||
if not (connection := connections.get("connection")):
|
if not (connection := connections.connections):
|
||||||
_LOGGER.warning("API returned invalid connection: %r", connections)
|
_LOGGER.warning("API returned invalid connection: %r", connections)
|
||||||
return
|
return
|
||||||
|
|
||||||
_LOGGER.debug("API returned connection: %r", connection)
|
_LOGGER.debug("API returned connection: %r", connection)
|
||||||
if int(connection[0]["departure"]["left"]) > 0:
|
if connection[0].departure.left:
|
||||||
next_connection = connection[1]
|
next_connection = connection[1]
|
||||||
else:
|
else:
|
||||||
next_connection = connection[0]
|
next_connection = connection[0]
|
||||||
@ -431,9 +427,9 @@ class NMBSSensor(SensorEntity):
|
|||||||
return
|
return
|
||||||
|
|
||||||
duration = get_ride_duration(
|
duration = get_ride_duration(
|
||||||
next_connection["departure"]["time"],
|
next_connection.departure.time,
|
||||||
next_connection["arrival"]["time"],
|
next_connection.arrival.time,
|
||||||
next_connection["departure"]["delay"],
|
next_connection.departure.delay,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._state = duration
|
self._state = duration
|
||||||
|
2
requirements_all.txt
generated
2
requirements_all.txt
generated
@ -2244,7 +2244,7 @@ pyqvrpro==0.52
|
|||||||
pyqwikswitch==0.93
|
pyqwikswitch==0.93
|
||||||
|
|
||||||
# homeassistant.components.nmbs
|
# homeassistant.components.nmbs
|
||||||
pyrail==0.0.3
|
pyrail==0.4.1
|
||||||
|
|
||||||
# homeassistant.components.rainbird
|
# homeassistant.components.rainbird
|
||||||
pyrainbird==6.0.1
|
pyrainbird==6.0.1
|
||||||
|
2
requirements_test_all.txt
generated
2
requirements_test_all.txt
generated
@ -1834,7 +1834,7 @@ pyps4-2ndscreen==1.3.1
|
|||||||
pyqwikswitch==0.93
|
pyqwikswitch==0.93
|
||||||
|
|
||||||
# homeassistant.components.nmbs
|
# homeassistant.components.nmbs
|
||||||
pyrail==0.0.3
|
pyrail==0.4.1
|
||||||
|
|
||||||
# homeassistant.components.rainbird
|
# homeassistant.components.rainbird
|
||||||
pyrainbird==6.0.1
|
pyrainbird==6.0.1
|
||||||
|
@ -1,20 +1 @@
|
|||||||
"""Tests for the NMBS integration."""
|
"""Tests for the NMBS integration."""
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from tests.common import load_fixture
|
|
||||||
|
|
||||||
|
|
||||||
def mock_api_unavailable() -> dict[str, Any]:
|
|
||||||
"""Mock for unavailable api."""
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
def mock_station_response() -> dict[str, Any]:
|
|
||||||
"""Mock for valid station response."""
|
|
||||||
dummy_stations_response: dict[str, Any] = json.loads(
|
|
||||||
load_fixture("stations.json", "nmbs")
|
|
||||||
)
|
|
||||||
|
|
||||||
return dummy_stations_response
|
|
||||||
|
@ -3,6 +3,7 @@
|
|||||||
from collections.abc import Generator
|
from collections.abc import Generator
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
from pyrail.models import StationsApiResponse
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from homeassistant.components.nmbs.const import (
|
from homeassistant.components.nmbs.const import (
|
||||||
@ -38,8 +39,8 @@ def mock_nmbs_client() -> Generator[AsyncMock]:
|
|||||||
),
|
),
|
||||||
):
|
):
|
||||||
client = mock_client.return_value
|
client = mock_client.return_value
|
||||||
client.get_stations.return_value = load_json_object_fixture(
|
client.get_stations.return_value = StationsApiResponse.from_dict(
|
||||||
"stations.json", DOMAIN
|
load_json_object_fixture("stations.json", DOMAIN)
|
||||||
)
|
)
|
||||||
yield client
|
yield client
|
||||||
|
|
||||||
|
@ -142,7 +142,7 @@ async def test_unavailable_api(
|
|||||||
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test starting a flow by user and api is unavailable."""
|
"""Test starting a flow by user and api is unavailable."""
|
||||||
mock_nmbs_client.get_stations.return_value = -1
|
mock_nmbs_client.get_stations.return_value = None
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
context={"source": config_entries.SOURCE_USER},
|
context={"source": config_entries.SOURCE_USER},
|
||||||
@ -203,7 +203,7 @@ async def test_unavailable_api_import(
|
|||||||
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
hass: HomeAssistant, mock_nmbs_client: AsyncMock
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Test starting a flow by import and api is unavailable."""
|
"""Test starting a flow by import and api is unavailable."""
|
||||||
mock_nmbs_client.get_stations.return_value = -1
|
mock_nmbs_client.get_stations.return_value = None
|
||||||
result = await hass.config_entries.flow.async_init(
|
result = await hass.config_entries.flow.async_init(
|
||||||
DOMAIN,
|
DOMAIN,
|
||||||
context={"source": SOURCE_IMPORT},
|
context={"source": SOURCE_IMPORT},
|
||||||
|
Loading…
x
Reference in New Issue
Block a user