Add Swiss public transport fetch connections service (#114671)

* add service to fetch more connections

* improve error messages

* better errors

* wip

* fix service register

* add working tests

* improve tests

* temp availability

* test availability

* remove availability test

* change error type for coordinator update

* fix missed coverage

* convert from entity service to integration service

* cleanup changes

* add more tests for the service
This commit is contained in:
Cyrill Raccaud 2024-08-12 11:26:42 +02:00 committed by GitHub
parent 8cfac68317
commit 0803ac9b0b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 541 additions and 10 deletions

View File

@ -11,18 +11,32 @@ from opendata_transport.exceptions import (
from homeassistant import config_entries, core
from homeassistant.const import Platform
from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady
from homeassistant.helpers import device_registry as dr, entity_registry as er
from homeassistant.helpers import (
config_validation as cv,
device_registry as dr,
entity_registry as er,
)
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from homeassistant.helpers.typing import ConfigType
from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS
from .coordinator import SwissPublicTransportDataUpdateCoordinator
from .helper import unique_id_from_config
from .services import setup_services
_LOGGER = logging.getLogger(__name__)
PLATFORMS: list[Platform] = [Platform.SENSOR]
CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN)
async def async_setup(hass: core.HomeAssistant, config: ConfigType) -> bool:
"""Set up the Swiss public transport component."""
setup_services(hass)
return True
async def async_setup_entry(
hass: core.HomeAssistant, entry: config_entries.ConfigEntry

View File

@ -9,12 +9,19 @@ CONF_START: Final = "from"
CONF_VIA: Final = "via"
DEFAULT_NAME = "Next Destination"
DEFAULT_UPDATE_TIME = 90
MAX_VIA = 5
SENSOR_CONNECTIONS_COUNT = 3
CONNECTIONS_COUNT = 3
CONNECTIONS_MAX = 15
PLACEHOLDERS = {
"stationboard_url": "http://transport.opendata.ch/examples/stationboard.html",
"opendata_url": "http://transport.opendata.ch",
}
ATTR_CONFIG_ENTRY_ID: Final = "config_entry_id"
ATTR_LIMIT: Final = "limit"
SERVICE_FETCH_CONNECTIONS = "fetch_connections"

View File

@ -7,14 +7,17 @@ import logging
from typing import TypedDict
from opendata_transport import OpendataTransport
from opendata_transport.exceptions import OpendataTransportError
from opendata_transport.exceptions import (
OpendataTransportConnectionError,
OpendataTransportError,
)
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed
import homeassistant.util.dt as dt_util
from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT
from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN
_LOGGER = logging.getLogger(__name__)
@ -54,7 +57,7 @@ class SwissPublicTransportDataUpdateCoordinator(
hass,
_LOGGER,
name=DOMAIN,
update_interval=timedelta(seconds=90),
update_interval=timedelta(seconds=DEFAULT_UPDATE_TIME),
)
self._opendata = opendata
@ -74,14 +77,21 @@ class SwissPublicTransportDataUpdateCoordinator(
return None
async def _async_update_data(self) -> list[DataConnection]:
return await self.fetch_connections(limit=CONNECTIONS_COUNT)
async def fetch_connections(self, limit: int) -> list[DataConnection]:
"""Fetch connections using the opendata api."""
self._opendata.limit = limit
try:
await self._opendata.async_get_data()
except OpendataTransportConnectionError as e:
_LOGGER.warning("Connection to transport.opendata.ch cannot be established")
raise UpdateFailed from e
except OpendataTransportError as e:
_LOGGER.warning(
"Unable to connect and retrieve data from transport.opendata.ch"
)
raise UpdateFailed from e
connections = self._opendata.connections
return [
DataConnection(
@ -95,6 +105,6 @@ class SwissPublicTransportDataUpdateCoordinator(
remaining_time=str(self.remaining_time(connections[i]["departure"])),
delay=connections[i]["delay"],
)
for i in range(SENSOR_CONNECTIONS_COUNT)
for i in range(limit)
if len(connections) > i and connections[i] is not None
]

View File

@ -23,5 +23,8 @@
"default": "mdi:clock-plus"
}
}
},
"services": {
"fetch_connections": "mdi:bus-clock"
}
}

View File

@ -20,7 +20,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.typing import StateType
from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT
from .const import CONNECTIONS_COUNT, DOMAIN
from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator
_LOGGER = logging.getLogger(__name__)
@ -46,7 +46,7 @@ SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = (
value_fn=lambda data_connection: data_connection["departure"],
index=i,
)
for i in range(SENSOR_CONNECTIONS_COUNT)
for i in range(CONNECTIONS_COUNT)
],
SwissPublicTransportSensorEntityDescription(
key="duration",

View File

@ -0,0 +1,89 @@
"""Define services for the Swiss public transport integration."""
import voluptuous as vol
from homeassistant import config_entries
from homeassistant.config_entries import ConfigEntryState
from homeassistant.core import (
HomeAssistant,
ServiceCall,
ServiceResponse,
SupportsResponse,
)
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from homeassistant.helpers.selector import (
NumberSelector,
NumberSelectorConfig,
NumberSelectorMode,
)
from homeassistant.helpers.update_coordinator import UpdateFailed
from .const import (
ATTR_CONFIG_ENTRY_ID,
ATTR_LIMIT,
CONNECTIONS_COUNT,
CONNECTIONS_MAX,
DOMAIN,
SERVICE_FETCH_CONNECTIONS,
)
SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema(
{
vol.Required(ATTR_CONFIG_ENTRY_ID): str,
vol.Optional(ATTR_LIMIT, default=CONNECTIONS_COUNT): NumberSelector(
NumberSelectorConfig(
min=1, max=CONNECTIONS_MAX, mode=NumberSelectorMode.BOX
)
),
}
)
def async_get_entry(
hass: HomeAssistant, config_entry_id: str
) -> config_entries.ConfigEntry:
"""Get the Swiss public transport config entry."""
if not (entry := hass.config_entries.async_get_entry(config_entry_id)):
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="config_entry_not_found",
translation_placeholders={"target": config_entry_id},
)
if entry.state is not ConfigEntryState.LOADED:
raise ServiceValidationError(
translation_domain=DOMAIN,
translation_key="not_loaded",
translation_placeholders={"target": entry.title},
)
return entry
def setup_services(hass: HomeAssistant) -> None:
"""Set up the services for the Swiss public transport integration."""
async def async_fetch_connections(
call: ServiceCall,
) -> ServiceResponse:
"""Fetch a set of connections."""
config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID])
limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT
coordinator = hass.data[DOMAIN][config_entry.entry_id]
try:
connections = await coordinator.fetch_connections(limit=int(limit))
except UpdateFailed as e:
raise HomeAssistantError(
translation_domain=DOMAIN,
translation_key="cannot_connect",
translation_placeholders={
"error": str(e),
},
) from e
return {"connections": connections}
hass.services.async_register(
DOMAIN,
SERVICE_FETCH_CONNECTIONS,
async_fetch_connections,
schema=SERVICE_FETCH_CONNECTIONS_SCHEMA,
supports_response=SupportsResponse.ONLY,
)

View File

@ -0,0 +1,14 @@
fetch_connections:
fields:
config_entry_id:
required: true
selector:
config_entry:
integration: swiss_public_transport
limit:
example: 3
selector:
number:
min: 1
max: 15
step: 1

View File

@ -49,12 +49,37 @@
}
}
},
"services": {
"fetch_connections": {
"name": "Fetch Connections",
"description": "Fetch a list of connections from the swiss public transport.",
"fields": {
"config_entry_id": {
"name": "Instance",
"description": "Swiss public transport instance to fetch connections for."
},
"limit": {
"name": "Limit",
"description": "Number of connections to fetch from [1-15]"
}
}
}
},
"exceptions": {
"invalid_data": {
"message": "Setup failed for entry {config_title} with invalid data, check at the [stationboard]({stationboard_url}) if your station names are valid.\n{error}"
},
"request_timeout": {
"message": "Timeout while connecting for entry {config_title}.\n{error}"
},
"cannot_connect": {
"message": "Cannot connect to server.\n{error}"
},
"not_loaded": {
"message": "{target} is not loaded."
},
"config_entry_not_found": {
"message": "Swiss public transport integration instance \"{target}\" not found."
}
}
}

View File

@ -41,6 +41,7 @@ ALLOW_NAME_TRANSLATION = {
"local_todo",
"nmap_tracker",
"rpi_power",
"swiss_public_transport",
"waze_travel_time",
"zodiac",
}

View File

@ -1 +1,13 @@
"""Tests for the swiss_public_transport integration."""
from homeassistant.core import HomeAssistant
from tests.common import MockConfigEntry
async def setup_integration(hass: HomeAssistant, config_entry: MockConfigEntry) -> None:
"""Fixture for setting up the component."""
config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()

View File

@ -0,0 +1,130 @@
[
{
"departure": "2024-01-06T18:03:00+0100",
"number": 0,
"platform": 0,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:04:00+0100",
"number": 1,
"platform": 1,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:05:00+0100",
"number": 2,
"platform": 2,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:06:00+0100",
"number": 3,
"platform": 3,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:07:00+0100",
"number": 4,
"platform": 4,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:08:00+0100",
"number": 5,
"platform": 5,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:09:00+0100",
"number": 6,
"platform": 6,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:10:00+0100",
"number": 7,
"platform": 7,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:11:00+0100",
"number": 8,
"platform": 8,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:12:00+0100",
"number": 9,
"platform": 9,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:13:00+0100",
"number": 10,
"platform": 10,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:14:00+0100",
"number": 11,
"platform": 11,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:15:00+0100",
"number": 12,
"platform": 12,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:16:00+0100",
"number": 13,
"platform": 13,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:17:00+0100",
"number": 14,
"platform": 14,
"transfers": 0,
"duration": "10",
"delay": 0
},
{
"departure": "2024-01-06T18:18:00+0100",
"number": 15,
"platform": 15,
"transfers": 0,
"duration": "10",
"delay": 0
}
]

View File

@ -1,4 +1,4 @@
"""Test the swiss_public_transport config flow."""
"""Test the swiss_public_transport integration."""
from unittest.mock import AsyncMock, patch

View File

@ -0,0 +1,226 @@
"""Test the swiss_public_transport service."""
import json
import logging
from unittest.mock import AsyncMock, patch
from opendata_transport.exceptions import (
OpendataTransportConnectionError,
OpendataTransportError,
)
import pytest
from voluptuous import error as vol_er
from homeassistant.components.swiss_public_transport.const import (
ATTR_CONFIG_ENTRY_ID,
ATTR_LIMIT,
CONF_DESTINATION,
CONF_START,
CONNECTIONS_COUNT,
CONNECTIONS_MAX,
DOMAIN,
SERVICE_FETCH_CONNECTIONS,
)
from homeassistant.components.swiss_public_transport.helper import unique_id_from_config
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import HomeAssistantError, ServiceValidationError
from . import setup_integration
from tests.common import MockConfigEntry, load_fixture
_LOGGER = logging.getLogger(__name__)
MOCK_DATA_STEP_BASE = {
CONF_START: "test_start",
CONF_DESTINATION: "test_destination",
}
@pytest.mark.parametrize(
("limit", "config_data"),
[
(1, MOCK_DATA_STEP_BASE),
(2, MOCK_DATA_STEP_BASE),
(3, MOCK_DATA_STEP_BASE),
(CONNECTIONS_MAX, MOCK_DATA_STEP_BASE),
(None, MOCK_DATA_STEP_BASE),
],
)
async def test_service_call_fetch_connections_success(
hass: HomeAssistant,
limit: int,
config_data,
) -> None:
"""Test the fetch_connections service."""
unique_id = unique_id_from_config(config_data)
config_entry = MockConfigEntry(
domain=DOMAIN,
data=config_data,
title=f"Service test call with limit={limit}",
unique_id=unique_id,
entry_id=f"entry_{unique_id}",
)
with patch(
"homeassistant.components.swiss_public_transport.OpendataTransport",
return_value=AsyncMock(),
) as mock:
mock().connections = json.loads(load_fixture("connections.json", DOMAIN))[
0 : (limit or CONNECTIONS_COUNT) + 2
]
await setup_integration(hass, config_entry)
data = {ATTR_CONFIG_ENTRY_ID: config_entry.entry_id}
if limit is not None:
data[ATTR_LIMIT] = limit
assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS)
response = await hass.services.async_call(
domain=DOMAIN,
service=SERVICE_FETCH_CONNECTIONS,
service_data=data,
blocking=True,
return_response=True,
)
await hass.async_block_till_done()
assert response["connections"] is not None
assert len(response["connections"]) == (limit or CONNECTIONS_COUNT)
@pytest.mark.parametrize(
("limit", "config_data", "expected_result", "raise_error"),
[
(-1, MOCK_DATA_STEP_BASE, pytest.raises(vol_er.MultipleInvalid), None),
(0, MOCK_DATA_STEP_BASE, pytest.raises(vol_er.MultipleInvalid), None),
(
CONNECTIONS_MAX + 1,
MOCK_DATA_STEP_BASE,
pytest.raises(vol_er.MultipleInvalid),
None,
),
(
1,
MOCK_DATA_STEP_BASE,
pytest.raises(HomeAssistantError),
OpendataTransportConnectionError(),
),
(
2,
MOCK_DATA_STEP_BASE,
pytest.raises(HomeAssistantError),
OpendataTransportError(),
),
],
)
async def test_service_call_fetch_connections_error(
hass: HomeAssistant,
limit,
config_data,
expected_result,
raise_error,
) -> None:
"""Test service call with standard error."""
unique_id = unique_id_from_config(config_data)
config_entry = MockConfigEntry(
domain=DOMAIN,
data=config_data,
title=f"Service test call with limit={limit} and error={raise_error}",
unique_id=unique_id,
entry_id=f"entry_{unique_id}",
)
with patch(
"homeassistant.components.swiss_public_transport.OpendataTransport",
return_value=AsyncMock(),
) as mock:
mock().connections = json.loads(load_fixture("connections.json", DOMAIN))
await setup_integration(hass, config_entry)
assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS)
mock().async_get_data.side_effect = raise_error
with expected_result:
await hass.services.async_call(
domain=DOMAIN,
service=SERVICE_FETCH_CONNECTIONS,
service_data={
ATTR_CONFIG_ENTRY_ID: config_entry.entry_id,
ATTR_LIMIT: limit,
},
blocking=True,
return_response=True,
)
async def test_service_call_load_unload(
hass: HomeAssistant,
) -> None:
"""Test service call with integration error."""
unique_id = unique_id_from_config(MOCK_DATA_STEP_BASE)
config_entry = MockConfigEntry(
domain=DOMAIN,
data=MOCK_DATA_STEP_BASE,
title="Service test call for unloaded entry",
unique_id=unique_id,
entry_id=f"entry_{unique_id}",
)
bad_entry_id = "bad_entry_id"
with patch(
"homeassistant.components.swiss_public_transport.OpendataTransport",
return_value=AsyncMock(),
) as mock:
mock().connections = json.loads(load_fixture("connections.json", DOMAIN))
await setup_integration(hass, config_entry)
assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS)
response = await hass.services.async_call(
domain=DOMAIN,
service=SERVICE_FETCH_CONNECTIONS,
service_data={
ATTR_CONFIG_ENTRY_ID: config_entry.entry_id,
},
blocking=True,
return_response=True,
)
await hass.async_block_till_done()
assert response["connections"] is not None
await hass.config_entries.async_unload(config_entry.entry_id)
await hass.async_block_till_done()
with pytest.raises(
ServiceValidationError, match=f"{config_entry.title} is not loaded"
):
await hass.services.async_call(
domain=DOMAIN,
service=SERVICE_FETCH_CONNECTIONS,
service_data={
ATTR_CONFIG_ENTRY_ID: config_entry.entry_id,
},
blocking=True,
return_response=True,
)
with pytest.raises(
ServiceValidationError,
match=f'Swiss public transport integration instance "{bad_entry_id}" not found',
):
await hass.services.async_call(
domain=DOMAIN,
service=SERVICE_FETCH_CONNECTIONS,
service_data={
ATTR_CONFIG_ENTRY_ID: bad_entry_id,
},
blocking=True,
return_response=True,
)