diff --git a/homeassistant/components/swiss_public_transport/__init__.py b/homeassistant/components/swiss_public_transport/__init__.py index 1242c95269e..3e29fb9c746 100644 --- a/homeassistant/components/swiss_public_transport/__init__.py +++ b/homeassistant/components/swiss_public_transport/__init__.py @@ -11,18 +11,32 @@ from opendata_transport.exceptions import ( from homeassistant import config_entries, core from homeassistant.const import Platform from homeassistant.exceptions import ConfigEntryError, ConfigEntryNotReady -from homeassistant.helpers import device_registry as dr, entity_registry as er +from homeassistant.helpers import ( + config_validation as cv, + device_registry as dr, + entity_registry as er, +) from homeassistant.helpers.aiohttp_client import async_get_clientsession +from homeassistant.helpers.typing import ConfigType from .const import CONF_DESTINATION, CONF_START, CONF_VIA, DOMAIN, PLACEHOLDERS from .coordinator import SwissPublicTransportDataUpdateCoordinator from .helper import unique_id_from_config +from .services import setup_services _LOGGER = logging.getLogger(__name__) PLATFORMS: list[Platform] = [Platform.SENSOR] +CONFIG_SCHEMA = cv.config_entry_only_config_schema(DOMAIN) + + +async def async_setup(hass: core.HomeAssistant, config: ConfigType) -> bool: + """Set up the Swiss public transport component.""" + setup_services(hass) + return True + async def async_setup_entry( hass: core.HomeAssistant, entry: config_entries.ConfigEntry diff --git a/homeassistant/components/swiss_public_transport/const.py b/homeassistant/components/swiss_public_transport/const.py index 32b6427ced5..c02f36f2f25 100644 --- a/homeassistant/components/swiss_public_transport/const.py +++ b/homeassistant/components/swiss_public_transport/const.py @@ -9,12 +9,19 @@ CONF_START: Final = "from" CONF_VIA: Final = "via" DEFAULT_NAME = "Next Destination" +DEFAULT_UPDATE_TIME = 90 MAX_VIA = 5 -SENSOR_CONNECTIONS_COUNT = 3 +CONNECTIONS_COUNT = 3 +CONNECTIONS_MAX = 15 PLACEHOLDERS = { "stationboard_url": "http://transport.opendata.ch/examples/stationboard.html", "opendata_url": "http://transport.opendata.ch", } + +ATTR_CONFIG_ENTRY_ID: Final = "config_entry_id" +ATTR_LIMIT: Final = "limit" + +SERVICE_FETCH_CONNECTIONS = "fetch_connections" diff --git a/homeassistant/components/swiss_public_transport/coordinator.py b/homeassistant/components/swiss_public_transport/coordinator.py index ae7e1b2366d..114215520ac 100644 --- a/homeassistant/components/swiss_public_transport/coordinator.py +++ b/homeassistant/components/swiss_public_transport/coordinator.py @@ -7,14 +7,17 @@ import logging from typing import TypedDict from opendata_transport import OpendataTransport -from opendata_transport.exceptions import OpendataTransportError +from opendata_transport.exceptions import ( + OpendataTransportConnectionError, + OpendataTransportError, +) from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, UpdateFailed import homeassistant.util.dt as dt_util -from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT +from .const import CONNECTIONS_COUNT, DEFAULT_UPDATE_TIME, DOMAIN _LOGGER = logging.getLogger(__name__) @@ -54,7 +57,7 @@ class SwissPublicTransportDataUpdateCoordinator( hass, _LOGGER, name=DOMAIN, - update_interval=timedelta(seconds=90), + update_interval=timedelta(seconds=DEFAULT_UPDATE_TIME), ) self._opendata = opendata @@ -74,14 +77,21 @@ class SwissPublicTransportDataUpdateCoordinator( return None async def _async_update_data(self) -> list[DataConnection]: + return await self.fetch_connections(limit=CONNECTIONS_COUNT) + + async def fetch_connections(self, limit: int) -> list[DataConnection]: + """Fetch connections using the opendata api.""" + self._opendata.limit = limit try: await self._opendata.async_get_data() + except OpendataTransportConnectionError as e: + _LOGGER.warning("Connection to transport.opendata.ch cannot be established") + raise UpdateFailed from e except OpendataTransportError as e: _LOGGER.warning( "Unable to connect and retrieve data from transport.opendata.ch" ) raise UpdateFailed from e - connections = self._opendata.connections return [ DataConnection( @@ -95,6 +105,6 @@ class SwissPublicTransportDataUpdateCoordinator( remaining_time=str(self.remaining_time(connections[i]["departure"])), delay=connections[i]["delay"], ) - for i in range(SENSOR_CONNECTIONS_COUNT) + for i in range(limit) if len(connections) > i and connections[i] is not None ] diff --git a/homeassistant/components/swiss_public_transport/icons.json b/homeassistant/components/swiss_public_transport/icons.json index 10573b8f5c3..7c2e5436834 100644 --- a/homeassistant/components/swiss_public_transport/icons.json +++ b/homeassistant/components/swiss_public_transport/icons.json @@ -23,5 +23,8 @@ "default": "mdi:clock-plus" } } + }, + "services": { + "fetch_connections": "mdi:bus-clock" } } diff --git a/homeassistant/components/swiss_public_transport/sensor.py b/homeassistant/components/swiss_public_transport/sensor.py index 88a6dbecae4..c186b963705 100644 --- a/homeassistant/components/swiss_public_transport/sensor.py +++ b/homeassistant/components/swiss_public_transport/sensor.py @@ -20,7 +20,7 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback from homeassistant.helpers.typing import StateType from homeassistant.helpers.update_coordinator import CoordinatorEntity -from .const import DOMAIN, SENSOR_CONNECTIONS_COUNT +from .const import CONNECTIONS_COUNT, DOMAIN from .coordinator import DataConnection, SwissPublicTransportDataUpdateCoordinator _LOGGER = logging.getLogger(__name__) @@ -46,7 +46,7 @@ SENSORS: tuple[SwissPublicTransportSensorEntityDescription, ...] = ( value_fn=lambda data_connection: data_connection["departure"], index=i, ) - for i in range(SENSOR_CONNECTIONS_COUNT) + for i in range(CONNECTIONS_COUNT) ], SwissPublicTransportSensorEntityDescription( key="duration", diff --git a/homeassistant/components/swiss_public_transport/services.py b/homeassistant/components/swiss_public_transport/services.py new file mode 100644 index 00000000000..e8b7c6bd458 --- /dev/null +++ b/homeassistant/components/swiss_public_transport/services.py @@ -0,0 +1,89 @@ +"""Define services for the Swiss public transport integration.""" + +import voluptuous as vol + +from homeassistant import config_entries +from homeassistant.config_entries import ConfigEntryState +from homeassistant.core import ( + HomeAssistant, + ServiceCall, + ServiceResponse, + SupportsResponse, +) +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError +from homeassistant.helpers.selector import ( + NumberSelector, + NumberSelectorConfig, + NumberSelectorMode, +) +from homeassistant.helpers.update_coordinator import UpdateFailed + +from .const import ( + ATTR_CONFIG_ENTRY_ID, + ATTR_LIMIT, + CONNECTIONS_COUNT, + CONNECTIONS_MAX, + DOMAIN, + SERVICE_FETCH_CONNECTIONS, +) + +SERVICE_FETCH_CONNECTIONS_SCHEMA = vol.Schema( + { + vol.Required(ATTR_CONFIG_ENTRY_ID): str, + vol.Optional(ATTR_LIMIT, default=CONNECTIONS_COUNT): NumberSelector( + NumberSelectorConfig( + min=1, max=CONNECTIONS_MAX, mode=NumberSelectorMode.BOX + ) + ), + } +) + + +def async_get_entry( + hass: HomeAssistant, config_entry_id: str +) -> config_entries.ConfigEntry: + """Get the Swiss public transport config entry.""" + if not (entry := hass.config_entries.async_get_entry(config_entry_id)): + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="config_entry_not_found", + translation_placeholders={"target": config_entry_id}, + ) + if entry.state is not ConfigEntryState.LOADED: + raise ServiceValidationError( + translation_domain=DOMAIN, + translation_key="not_loaded", + translation_placeholders={"target": entry.title}, + ) + return entry + + +def setup_services(hass: HomeAssistant) -> None: + """Set up the services for the Swiss public transport integration.""" + + async def async_fetch_connections( + call: ServiceCall, + ) -> ServiceResponse: + """Fetch a set of connections.""" + config_entry = async_get_entry(hass, call.data[ATTR_CONFIG_ENTRY_ID]) + limit = call.data.get(ATTR_LIMIT) or CONNECTIONS_COUNT + coordinator = hass.data[DOMAIN][config_entry.entry_id] + try: + connections = await coordinator.fetch_connections(limit=int(limit)) + except UpdateFailed as e: + raise HomeAssistantError( + translation_domain=DOMAIN, + translation_key="cannot_connect", + translation_placeholders={ + "error": str(e), + }, + ) from e + return {"connections": connections} + + hass.services.async_register( + DOMAIN, + SERVICE_FETCH_CONNECTIONS, + async_fetch_connections, + schema=SERVICE_FETCH_CONNECTIONS_SCHEMA, + supports_response=SupportsResponse.ONLY, + ) diff --git a/homeassistant/components/swiss_public_transport/services.yaml b/homeassistant/components/swiss_public_transport/services.yaml new file mode 100644 index 00000000000..d88dad2ca1f --- /dev/null +++ b/homeassistant/components/swiss_public_transport/services.yaml @@ -0,0 +1,14 @@ +fetch_connections: + fields: + config_entry_id: + required: true + selector: + config_entry: + integration: swiss_public_transport + limit: + example: 3 + selector: + number: + min: 1 + max: 15 + step: 1 diff --git a/homeassistant/components/swiss_public_transport/strings.json b/homeassistant/components/swiss_public_transport/strings.json index 4f4bc0522fc..29e73978538 100644 --- a/homeassistant/components/swiss_public_transport/strings.json +++ b/homeassistant/components/swiss_public_transport/strings.json @@ -49,12 +49,37 @@ } } }, + "services": { + "fetch_connections": { + "name": "Fetch Connections", + "description": "Fetch a list of connections from the swiss public transport.", + "fields": { + "config_entry_id": { + "name": "Instance", + "description": "Swiss public transport instance to fetch connections for." + }, + "limit": { + "name": "Limit", + "description": "Number of connections to fetch from [1-15]" + } + } + } + }, "exceptions": { "invalid_data": { "message": "Setup failed for entry {config_title} with invalid data, check at the [stationboard]({stationboard_url}) if your station names are valid.\n{error}" }, "request_timeout": { "message": "Timeout while connecting for entry {config_title}.\n{error}" + }, + "cannot_connect": { + "message": "Cannot connect to server.\n{error}" + }, + "not_loaded": { + "message": "{target} is not loaded." + }, + "config_entry_not_found": { + "message": "Swiss public transport integration instance \"{target}\" not found." } } } diff --git a/script/hassfest/translations.py b/script/hassfest/translations.py index c39c070eba2..c5efd05948f 100644 --- a/script/hassfest/translations.py +++ b/script/hassfest/translations.py @@ -41,6 +41,7 @@ ALLOW_NAME_TRANSLATION = { "local_todo", "nmap_tracker", "rpi_power", + "swiss_public_transport", "waze_travel_time", "zodiac", } diff --git a/tests/components/swiss_public_transport/__init__.py b/tests/components/swiss_public_transport/__init__.py index 3859a630c31..98262324b11 100644 --- a/tests/components/swiss_public_transport/__init__.py +++ b/tests/components/swiss_public_transport/__init__.py @@ -1 +1,13 @@ """Tests for the swiss_public_transport integration.""" + +from homeassistant.core import HomeAssistant + +from tests.common import MockConfigEntry + + +async def setup_integration(hass: HomeAssistant, config_entry: MockConfigEntry) -> None: + """Fixture for setting up the component.""" + config_entry.add_to_hass(hass) + + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() diff --git a/tests/components/swiss_public_transport/fixtures/connections.json b/tests/components/swiss_public_transport/fixtures/connections.json new file mode 100644 index 00000000000..4edead56f14 --- /dev/null +++ b/tests/components/swiss_public_transport/fixtures/connections.json @@ -0,0 +1,130 @@ +[ + { + "departure": "2024-01-06T18:03:00+0100", + "number": 0, + "platform": 0, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:04:00+0100", + "number": 1, + "platform": 1, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:05:00+0100", + "number": 2, + "platform": 2, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:06:00+0100", + "number": 3, + "platform": 3, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:07:00+0100", + "number": 4, + "platform": 4, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:08:00+0100", + "number": 5, + "platform": 5, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:09:00+0100", + "number": 6, + "platform": 6, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:10:00+0100", + "number": 7, + "platform": 7, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:11:00+0100", + "number": 8, + "platform": 8, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:12:00+0100", + "number": 9, + "platform": 9, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:13:00+0100", + "number": 10, + "platform": 10, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:14:00+0100", + "number": 11, + "platform": 11, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:15:00+0100", + "number": 12, + "platform": 12, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:16:00+0100", + "number": 13, + "platform": 13, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:17:00+0100", + "number": 14, + "platform": 14, + "transfers": 0, + "duration": "10", + "delay": 0 + }, + { + "departure": "2024-01-06T18:18:00+0100", + "number": 15, + "platform": 15, + "transfers": 0, + "duration": "10", + "delay": 0 + } +] diff --git a/tests/components/swiss_public_transport/test_init.py b/tests/components/swiss_public_transport/test_init.py index 47360f93cf2..7ee8b696499 100644 --- a/tests/components/swiss_public_transport/test_init.py +++ b/tests/components/swiss_public_transport/test_init.py @@ -1,4 +1,4 @@ -"""Test the swiss_public_transport config flow.""" +"""Test the swiss_public_transport integration.""" from unittest.mock import AsyncMock, patch diff --git a/tests/components/swiss_public_transport/test_service.py b/tests/components/swiss_public_transport/test_service.py new file mode 100644 index 00000000000..34640de9f21 --- /dev/null +++ b/tests/components/swiss_public_transport/test_service.py @@ -0,0 +1,226 @@ +"""Test the swiss_public_transport service.""" + +import json +import logging +from unittest.mock import AsyncMock, patch + +from opendata_transport.exceptions import ( + OpendataTransportConnectionError, + OpendataTransportError, +) +import pytest +from voluptuous import error as vol_er + +from homeassistant.components.swiss_public_transport.const import ( + ATTR_CONFIG_ENTRY_ID, + ATTR_LIMIT, + CONF_DESTINATION, + CONF_START, + CONNECTIONS_COUNT, + CONNECTIONS_MAX, + DOMAIN, + SERVICE_FETCH_CONNECTIONS, +) +from homeassistant.components.swiss_public_transport.helper import unique_id_from_config +from homeassistant.core import HomeAssistant +from homeassistant.exceptions import HomeAssistantError, ServiceValidationError + +from . import setup_integration + +from tests.common import MockConfigEntry, load_fixture + +_LOGGER = logging.getLogger(__name__) + +MOCK_DATA_STEP_BASE = { + CONF_START: "test_start", + CONF_DESTINATION: "test_destination", +} + + +@pytest.mark.parametrize( + ("limit", "config_data"), + [ + (1, MOCK_DATA_STEP_BASE), + (2, MOCK_DATA_STEP_BASE), + (3, MOCK_DATA_STEP_BASE), + (CONNECTIONS_MAX, MOCK_DATA_STEP_BASE), + (None, MOCK_DATA_STEP_BASE), + ], +) +async def test_service_call_fetch_connections_success( + hass: HomeAssistant, + limit: int, + config_data, +) -> None: + """Test the fetch_connections service.""" + + unique_id = unique_id_from_config(config_data) + + config_entry = MockConfigEntry( + domain=DOMAIN, + data=config_data, + title=f"Service test call with limit={limit}", + unique_id=unique_id, + entry_id=f"entry_{unique_id}", + ) + + with patch( + "homeassistant.components.swiss_public_transport.OpendataTransport", + return_value=AsyncMock(), + ) as mock: + mock().connections = json.loads(load_fixture("connections.json", DOMAIN))[ + 0 : (limit or CONNECTIONS_COUNT) + 2 + ] + + await setup_integration(hass, config_entry) + + data = {ATTR_CONFIG_ENTRY_ID: config_entry.entry_id} + if limit is not None: + data[ATTR_LIMIT] = limit + assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS) + response = await hass.services.async_call( + domain=DOMAIN, + service=SERVICE_FETCH_CONNECTIONS, + service_data=data, + blocking=True, + return_response=True, + ) + await hass.async_block_till_done() + assert response["connections"] is not None + assert len(response["connections"]) == (limit or CONNECTIONS_COUNT) + + +@pytest.mark.parametrize( + ("limit", "config_data", "expected_result", "raise_error"), + [ + (-1, MOCK_DATA_STEP_BASE, pytest.raises(vol_er.MultipleInvalid), None), + (0, MOCK_DATA_STEP_BASE, pytest.raises(vol_er.MultipleInvalid), None), + ( + CONNECTIONS_MAX + 1, + MOCK_DATA_STEP_BASE, + pytest.raises(vol_er.MultipleInvalid), + None, + ), + ( + 1, + MOCK_DATA_STEP_BASE, + pytest.raises(HomeAssistantError), + OpendataTransportConnectionError(), + ), + ( + 2, + MOCK_DATA_STEP_BASE, + pytest.raises(HomeAssistantError), + OpendataTransportError(), + ), + ], +) +async def test_service_call_fetch_connections_error( + hass: HomeAssistant, + limit, + config_data, + expected_result, + raise_error, +) -> None: + """Test service call with standard error.""" + + unique_id = unique_id_from_config(config_data) + + config_entry = MockConfigEntry( + domain=DOMAIN, + data=config_data, + title=f"Service test call with limit={limit} and error={raise_error}", + unique_id=unique_id, + entry_id=f"entry_{unique_id}", + ) + + with patch( + "homeassistant.components.swiss_public_transport.OpendataTransport", + return_value=AsyncMock(), + ) as mock: + mock().connections = json.loads(load_fixture("connections.json", DOMAIN)) + + await setup_integration(hass, config_entry) + + assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS) + mock().async_get_data.side_effect = raise_error + with expected_result: + await hass.services.async_call( + domain=DOMAIN, + service=SERVICE_FETCH_CONNECTIONS, + service_data={ + ATTR_CONFIG_ENTRY_ID: config_entry.entry_id, + ATTR_LIMIT: limit, + }, + blocking=True, + return_response=True, + ) + + +async def test_service_call_load_unload( + hass: HomeAssistant, +) -> None: + """Test service call with integration error.""" + + unique_id = unique_id_from_config(MOCK_DATA_STEP_BASE) + + config_entry = MockConfigEntry( + domain=DOMAIN, + data=MOCK_DATA_STEP_BASE, + title="Service test call for unloaded entry", + unique_id=unique_id, + entry_id=f"entry_{unique_id}", + ) + + bad_entry_id = "bad_entry_id" + + with patch( + "homeassistant.components.swiss_public_transport.OpendataTransport", + return_value=AsyncMock(), + ) as mock: + mock().connections = json.loads(load_fixture("connections.json", DOMAIN)) + + await setup_integration(hass, config_entry) + + assert hass.services.has_service(DOMAIN, SERVICE_FETCH_CONNECTIONS) + response = await hass.services.async_call( + domain=DOMAIN, + service=SERVICE_FETCH_CONNECTIONS, + service_data={ + ATTR_CONFIG_ENTRY_ID: config_entry.entry_id, + }, + blocking=True, + return_response=True, + ) + await hass.async_block_till_done() + assert response["connections"] is not None + + await hass.config_entries.async_unload(config_entry.entry_id) + await hass.async_block_till_done() + + with pytest.raises( + ServiceValidationError, match=f"{config_entry.title} is not loaded" + ): + await hass.services.async_call( + domain=DOMAIN, + service=SERVICE_FETCH_CONNECTIONS, + service_data={ + ATTR_CONFIG_ENTRY_ID: config_entry.entry_id, + }, + blocking=True, + return_response=True, + ) + + with pytest.raises( + ServiceValidationError, + match=f'Swiss public transport integration instance "{bad_entry_id}" not found', + ): + await hass.services.async_call( + domain=DOMAIN, + service=SERVICE_FETCH_CONNECTIONS, + service_data={ + ATTR_CONFIG_ENTRY_ID: bad_entry_id, + }, + blocking=True, + return_response=True, + )