Change Trafikverket Train to use station signatures (#131416)

Co-authored-by: Robert Resch <robert@resch.dev>
This commit is contained in:
G Johansson 2025-01-13 15:38:02 +01:00 committed by GitHub
parent 157548609b
commit 4709a3162c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 516 additions and 220 deletions

View File

@ -4,11 +4,21 @@ from __future__ import annotations
import logging import logging
from homeassistant.config_entries import ConfigEntry from pytrafikverket import (
from homeassistant.core import HomeAssistant InvalidAuthentication,
from homeassistant.helpers import entity_registry as er NoTrainStationFound,
TrafikverketTrain,
UnknownError,
)
from .const import PLATFORMS from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_FROM, CONF_TO, PLATFORMS
from .coordinator import TVDataUpdateCoordinator from .coordinator import TVDataUpdateCoordinator
TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator] TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator]
@ -52,13 +62,55 @@ async def async_migrate_entry(hass: HomeAssistant, entry: TVTrainConfigEntry) ->
"""Migrate config entry.""" """Migrate config entry."""
_LOGGER.debug("Migrating from version %s", entry.version) _LOGGER.debug("Migrating from version %s", entry.version)
if entry.version > 1: if entry.version > 2:
# This means the user has downgraded from a future version # This means the user has downgraded from a future version
return False return False
if entry.version == 1 and entry.minor_version == 1: if entry.version == 1:
# Remove unique id if entry.minor_version == 1:
hass.config_entries.async_update_entry(entry, unique_id=None, minor_version=2) # Remove unique id
hass.config_entries.async_update_entry(
entry, unique_id=None, minor_version=2
)
# Change from station names to station signatures
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, entry.data[CONF_API_KEY])
from_stations = await train_api.async_search_train_stations(
entry.data[CONF_FROM]
)
to_stations = await train_api.async_search_train_stations(
entry.data[CONF_TO]
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
except NoTrainStationFound as error:
_LOGGER.error(
"Migration failed as no train station found with provided name %s",
str(error),
)
return False
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
return False
except Exception as error: # noqa: BLE001
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
return False
if len(from_stations) > 1 or len(to_stations) > 1:
_LOGGER.error(
"Migration failed as more than one station found with provided name"
)
return False
new_data = entry.data.copy()
new_data[CONF_FROM] = from_stations[0].signature
new_data[CONF_TO] = to_stations[0].signature
hass.config_entries.async_update_entry(
entry, data=new_data, version=2, minor_version=1
)
_LOGGER.debug( _LOGGER.debug(
"Migration to version %s.%s successful", "Migration to version %s.%s successful",

View File

@ -3,16 +3,14 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from datetime import datetime
import logging import logging
from typing import Any from typing import Any
from pytrafikverket import TrafikverketTrain from pytrafikverket import (
from pytrafikverket.exceptions import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound, NoTrainStationFound,
StationInfoModel,
TrafikverketTrain,
UnknownError, UnknownError,
) )
import voluptuous as vol import voluptuous as vol
@ -28,16 +26,15 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import ( from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector, SelectSelector,
SelectSelectorConfig, SelectSelectorConfig,
SelectSelectorMode, SelectSelectorMode,
TextSelector, TextSelector,
TimeSelector, TimeSelector,
) )
import homeassistant.util.dt as dt_util
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import next_departuredate
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -68,49 +65,23 @@ DATA_SCHEMA_REAUTH = vol.Schema(
) )
async def validate_input( async def validate_station(
hass: HomeAssistant, hass: HomeAssistant,
api_key: str, api_key: str,
train_from: str, train_station: str,
train_to: str, field: str,
train_time: str | None, ) -> tuple[list[StationInfoModel], dict[str, str]]:
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
"""Validate input from user input.""" """Validate input from user input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
stations = []
when = dt_util.now()
if train_time:
departure_day = next_departuredate(weekdays)
if _time := dt_util.parse_time(train_time):
when = datetime.combine(
departure_day,
_time,
dt_util.get_default_time_zone(),
)
try: try:
web_session = async_get_clientsession(hass) web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, api_key) train_api = TrafikverketTrain(web_session, api_key)
from_station = await train_api.async_search_train_station(train_from) stations = await train_api.async_search_train_stations(train_station)
to_station = await train_api.async_search_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
except InvalidAuthentication: except InvalidAuthentication:
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
except NoTrainStationFound: except NoTrainStationFound:
errors["base"] = "invalid_station" errors[field] = "invalid_station"
except MultipleTrainStationsFound:
errors["base"] = "more_stations"
except NoTrainAnnouncementFound:
errors["base"] = "no_trains"
except UnknownError as error: except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error)) _LOGGER.error("Unknown error occurred during validation %s", str(error))
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
@ -118,14 +89,18 @@ async def validate_input(
_LOGGER.error("Unknown exception occurred during validation %s", str(error)) _LOGGER.error("Unknown exception occurred during validation %s", str(error))
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
return errors return (stations, errors)
class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN): class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Trafikverket Train integration.""" """Handle a config flow for Trafikverket Train integration."""
VERSION = 1 VERSION = 2
MINOR_VERSION = 2 MINOR_VERSION = 1
_from_stations: list[StationInfoModel]
_to_stations: list[StationInfoModel]
_data: dict[str, Any]
@staticmethod @staticmethod
@callback @callback
@ -151,14 +126,11 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
api_key = user_input[CONF_API_KEY] api_key = user_input[CONF_API_KEY]
reauth_entry = self._get_reauth_entry() reauth_entry = self._get_reauth_entry()
errors = await validate_input( _, errors = await validate_station(
self.hass, self.hass,
api_key, api_key,
reauth_entry.data[CONF_FROM], reauth_entry.data[CONF_FROM],
reauth_entry.data[CONF_TO], CONF_FROM,
reauth_entry.data.get(CONF_TIME),
reauth_entry.data[CONF_WEEKDAY],
reauth_entry.options.get(CONF_FILTER_PRODUCT),
) )
if not errors: if not errors:
return self.async_update_reload_and_abort( return self.async_update_reload_and_abort(
@ -193,38 +165,40 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
if train_time: if train_time:
name = f"{train_from} to {train_to} at {train_time}" name = f"{train_from} to {train_to} at {train_time}"
errors = await validate_input( self._from_stations, from_errors = await validate_station(
self.hass, self.hass, api_key, train_from, CONF_FROM
api_key,
train_from,
train_to,
train_time,
train_days,
filter_product,
) )
self._to_stations, to_errors = await validate_station(
self.hass, api_key, train_to, CONF_TO
)
errors = {**from_errors, **to_errors}
if not errors: if not errors:
self._async_abort_entries_match( if len(self._from_stations) == 1 and len(self._to_stations) == 1:
{ self._async_abort_entries_match(
CONF_API_KEY: api_key, {
CONF_FROM: train_from, CONF_API_KEY: api_key,
CONF_TO: train_to, CONF_FROM: self._from_stations[0].signature,
CONF_TIME: train_time, CONF_TO: self._to_stations[0].signature,
CONF_WEEKDAY: train_days, CONF_TIME: train_time,
CONF_FILTER_PRODUCT: filter_product, CONF_WEEKDAY: train_days,
} CONF_FILTER_PRODUCT: filter_product,
) }
return self.async_create_entry( )
title=name, return self.async_create_entry(
data={ title=name,
CONF_API_KEY: api_key, data={
CONF_NAME: name, CONF_API_KEY: api_key,
CONF_FROM: train_from, CONF_NAME: name,
CONF_TO: train_to, CONF_FROM: self._from_stations[0].signature,
CONF_TIME: train_time, CONF_TO: self._to_stations[0].signature,
CONF_WEEKDAY: train_days, CONF_TIME: train_time,
}, CONF_WEEKDAY: train_days,
options={CONF_FILTER_PRODUCT: filter_product}, },
) options={CONF_FILTER_PRODUCT: filter_product},
)
self._data = user_input
return await self.async_step_select_stations()
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
@ -234,6 +208,77 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors, errors=errors,
) )
async def async_step_select_stations(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the select station step."""
if user_input is not None:
api_key: str = self._data[CONF_API_KEY]
train_from: str = user_input[CONF_FROM]
train_to: str = user_input[CONF_TO]
train_time: str | None = self._data.get(CONF_TIME)
train_days: list = self._data[CONF_WEEKDAY]
filter_product: str | None = self._data[CONF_FILTER_PRODUCT]
if filter_product == "":
filter_product = None
name = f"{self._data[CONF_FROM]} to {self._data[CONF_TO]}"
if train_time:
name = (
f"{self._data[CONF_FROM]} to {self._data[CONF_TO]} at {train_time}"
)
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: train_from,
CONF_TO: user_input[CONF_TO],
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
from_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._from_stations
]
to_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._to_stations
]
schema = {}
if len(from_options) > 1:
schema[vol.Required(CONF_FROM)] = SelectSelector(
SelectSelectorConfig(
options=from_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)
if len(to_options) > 1:
schema[vol.Required(CONF_TO)] = SelectSelector(
SelectSelectorConfig(
options=to_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)
return self.async_show_form(
step_id="select_stations",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(schema), user_input or {}
),
)
class TVTrainOptionsFlowHandler(OptionsFlow): class TVTrainOptionsFlowHandler(OptionsFlow):
"""Handle Trafikverket Train options.""" """Handle Trafikverket Train options."""

View File

@ -7,15 +7,16 @@ from datetime import datetime, time, timedelta
import logging import logging
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from pytrafikverket import TrafikverketTrain from pytrafikverket import (
from pytrafikverket.exceptions import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainStationsFound, MultipleTrainStationsFound,
NoTrainAnnouncementFound, NoTrainAnnouncementFound,
NoTrainStationFound, NoTrainStationFound,
StationInfoModel,
TrafikverketTrain,
TrainStopModel,
UnknownError, UnknownError,
) )
from pytrafikverket.models import StationInfoModel, TrainStopModel
from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
@ -93,11 +94,15 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
async def _async_setup(self) -> None: async def _async_setup(self) -> None:
"""Initiate stations.""" """Initiate stations."""
try: try:
self.to_station = await self._train_api.async_search_train_station( self.to_station = (
self.config_entry.data[CONF_TO] await self._train_api.async_get_train_station_from_signature(
self.config_entry.data[CONF_TO]
)
) )
self.from_station = await self._train_api.async_search_train_station( self.from_station = (
self.config_entry.data[CONF_FROM] await self._train_api.async_get_train_station_from_signature(
self.config_entry.data[CONF_FROM]
)
) )
except InvalidAuthentication as error: except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error raise ConfigEntryAuthFailed from error

View File

@ -27,6 +27,13 @@
"filter_product": "To filter by product description add the phrase here to match" "filter_product": "To filter by product description add the phrase here to match"
} }
}, },
"select_stations": {
"description": "More than one station was found with the provided name, select the correct ones from the provided lists",
"data": {
"to": "To station",
"from": "From station"
}
},
"reauth_confirm": { "reauth_confirm": {
"data": { "data": {
"api_key": "[%key:common::config_flow::data::api_key%]" "api_key": "[%key:common::config_flow::data::api_key%]"

View File

@ -6,7 +6,7 @@ from datetime import datetime, timedelta
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from pytrafikverket.models import TrainStopModel from pytrafikverket import StationInfoModel, TrainStopModel
from homeassistant.components.trafikverket_train.const import DOMAIN from homeassistant.components.trafikverket_train.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER from homeassistant.config_entries import SOURCE_USER
@ -40,6 +40,9 @@ async def load_integration_from_entry(
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
), ),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
): ):
await hass.config_entries.async_setup(config_entry_id) await hass.config_entries.async_setup(config_entry_id)
await hass.async_block_till_done() await hass.async_block_till_done()
@ -50,8 +53,8 @@ async def load_integration_from_entry(
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
entry_id="1", entry_id="1",
version=1, version=2,
minor_version=2, minor_version=1,
) )
config_entry.add_to_hass(hass) config_entry.add_to_hass(hass)
await setup_config_entry_with_mocked_data(config_entry.entry_id) await setup_config_entry_with_mocked_data(config_entry.entry_id)
@ -61,8 +64,8 @@ async def load_integration_from_entry(
source=SOURCE_USER, source=SOURCE_USER,
data=ENTRY_CONFIG2, data=ENTRY_CONFIG2,
entry_id="2", entry_id="2",
version=1, version=2,
minor_version=2, minor_version=1,
) )
config_entry2.add_to_hass(hass) config_entry2.add_to_hass(hass)
await setup_config_entry_with_mocked_data(config_entry2.entry_id) await setup_config_entry_with_mocked_data(config_entry2.entry_id)
@ -171,3 +174,57 @@ def fixture_get_train_stop() -> TrainStopModel:
modified_time=datetime(2023, 5, 1, 11, 0, tzinfo=dt_util.UTC), modified_time=datetime(2023, 5, 1, 11, 0, tzinfo=dt_util.UTC),
product_description=["Regionaltåg"], product_description=["Regionaltåg"],
) )
@pytest.fixture(name="get_train_stations")
def fixture_get_train_station() -> list[list[StationInfoModel]]:
"""Construct StationInfoModel Mock."""
return [
[
StationInfoModel(
signature="Cst",
station_name="Stockholm C",
advertised=True,
)
],
[
StationInfoModel(
signature="U",
station_name="Uppsala C",
advertised=True,
)
],
]
@pytest.fixture(name="get_multiple_train_stations")
def fixture_get_multiple_train_station() -> list[list[StationInfoModel]]:
"""Construct StationInfoModel Mock."""
return [
[
StationInfoModel(
signature="Cst",
station_name="Stockholm C",
advertised=True,
),
StationInfoModel(
signature="Csu",
station_name="Stockholm City",
advertised=True,
),
],
[
StationInfoModel(
signature="U",
station_name="Uppsala C",
advertised=True,
),
StationInfoModel(
signature="Ups",
station_name="Uppsala City",
advertised=True,
),
],
]

View File

@ -5,14 +5,13 @@ from __future__ import annotations
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from pytrafikverket.exceptions import ( from pytrafikverket import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound, NoTrainStationFound,
StationInfoModel,
TrainStopModel,
UnknownError, UnknownError,
) )
from pytrafikverket.models import TrainStopModel
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.components.trafikverket_train.const import ( from homeassistant.components.trafikverket_train.const import (
@ -29,7 +28,9 @@ from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry from tests.common import MockConfigEntry
async def test_form(hass: HomeAssistant) -> None: async def test_form(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test we get the form.""" """Test we get the form."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -40,10 +41,8 @@ async def test_form(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
), side_effect=get_train_stations,
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
@ -67,8 +66,8 @@ async def test_form(hass: HomeAssistant) -> None:
assert result["data"] == { assert result["data"] == {
"api_key": "1234567890", "api_key": "1234567890",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C", "from": "Cst",
"to": "Uppsala C", "to": "U",
"time": "10:00", "time": "10:00",
"weekday": ["mon", "fri"], "weekday": ["mon", "fri"],
} }
@ -76,7 +75,70 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
async def test_form_entry_already_exist(hass: HomeAssistant) -> None: async def test_form_multiple_stations(
hass: HomeAssistant, get_multiple_train_stations: list[StationInfoModel]
) -> None:
"""Test we get the form."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_API_KEY: "1234567890",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_TIME: "10:00",
CONF_WEEKDAY: ["mon", "fri"],
},
)
await hass.async_block_till_done()
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True,
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_FROM: "Csu",
CONF_TO: "Ups",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "Stockholm C to Uppsala C at 10:00"
assert result["data"] == {
"api_key": "1234567890",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Csu",
"to": "Ups",
"time": "10:00",
"weekday": ["mon", "fri"],
}
assert result["options"] == {"filter_product": None}
async def test_form_entry_already_exist(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test flow aborts when entry already exist.""" """Test flow aborts when entry already exist."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -84,14 +146,14 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
data={ data={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00", CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C", CONF_FROM: "Cst",
CONF_TO: "Uppsala C", CONF_TO: "U",
CONF_TIME: "10:00", CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS, CONF_WEEKDAY: WEEKDAYS,
CONF_FILTER_PRODUCT: None, CONF_FILTER_PRODUCT: None,
}, },
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -103,10 +165,11 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
), ),
patch( patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
@ -130,28 +193,24 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "base_error"), ("side_effect", "p_error"),
[ [
( (
InvalidAuthentication, InvalidAuthentication,
"invalid_auth", {"base": "invalid_auth"},
), ),
( (
NoTrainStationFound, NoTrainStationFound,
"invalid_station", {"from": "invalid_station", "to": "invalid_station"},
),
(
MultipleTrainStationsFound,
"more_stations",
), ),
( (
Exception, Exception,
"cannot_connect", {"base": "cannot_connect"},
), ),
], ],
) )
async def test_flow_fails( async def test_flow_fails(
hass: HomeAssistant, side_effect: Exception, base_error: str hass: HomeAssistant, side_effect: Exception, p_error: dict[str, str]
) -> None: ) -> None:
"""Test config flow errors.""" """Test config flow errors."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -163,12 +222,9 @@ async def test_flow_fails(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(), side_effect=side_effect(),
), ),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
@ -179,24 +235,24 @@ async def test_flow_fails(
}, },
) )
assert result["errors"] == {"base": base_error} assert result["errors"] == p_error
@pytest.mark.parametrize( @pytest.mark.parametrize(
("side_effect", "base_error"), ("side_effect", "p_error"),
[ [
( (
NoTrainAnnouncementFound, NoTrainStationFound,
"no_trains", {"from": "invalid_station", "to": "invalid_station"},
), ),
( (
UnknownError, UnknownError,
"cannot_connect", {"base": "cannot_connect"},
), ),
], ],
) )
async def test_flow_fails_departures( async def test_flow_fails_departures(
hass: HomeAssistant, side_effect: Exception, base_error: str hass: HomeAssistant, side_effect: Exception, p_error: dict[str, str]
) -> None: ) -> None:
"""Test config flow errors.""" """Test config flow errors."""
result = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
@ -208,15 +264,9 @@ async def test_flow_fails_departures(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_next_train_stops",
side_effect=side_effect(), side_effect=side_effect(),
), ),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
@ -227,23 +277,25 @@ async def test_flow_fails_departures(
}, },
) )
assert result["errors"] == {"base": base_error} assert result["errors"] == p_error
async def test_reauth_flow(hass: HomeAssistant) -> None: async def test_reauth_flow(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test a reauthentication flow.""" """Test a reauthentication flow."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
data={ data={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00", CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C", CONF_FROM: "Cst",
CONF_TO: "Uppsala C", CONF_TO: "U",
CONF_TIME: "10:00", CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS, CONF_WEEKDAY: WEEKDAYS,
}, },
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -254,10 +306,8 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
), side_effect=get_train_stations,
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
@ -275,8 +325,8 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
assert entry.data == { assert entry.data == {
"api_key": "1234567891", "api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C", "from": "Cst",
"to": "Uppsala C", "to": "U",
"time": "10:00", "time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
} }
@ -287,24 +337,27 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
[ [
( (
InvalidAuthentication, InvalidAuthentication,
"invalid_auth", {"base": "invalid_auth"},
), ),
( (
NoTrainStationFound, NoTrainStationFound,
"invalid_station", {"from": "invalid_station"},
), ),
( (
MultipleTrainStationsFound, UnknownError,
"more_stations", {"base": "cannot_connect"},
), ),
( (
Exception, Exception,
"cannot_connect", {"base": "cannot_connect"},
), ),
], ],
) )
async def test_reauth_flow_error( async def test_reauth_flow_error(
hass: HomeAssistant, side_effect: Exception, p_error: str hass: HomeAssistant,
side_effect: Exception,
p_error: dict[str, str],
get_train_stations: list[StationInfoModel],
) -> None: ) -> None:
"""Test a reauthentication flow with error.""" """Test a reauthentication flow with error."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -312,13 +365,13 @@ async def test_reauth_flow_error(
data={ data={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00", CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C", CONF_FROM: "Cst",
CONF_TO: "Uppsala C", CONF_TO: "U",
CONF_TIME: "10:00", CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS, CONF_WEEKDAY: WEEKDAYS,
}, },
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -326,12 +379,9 @@ async def test_reauth_flow_error(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(), side_effect=side_effect(),
), ),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
): ):
result = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
@ -341,14 +391,12 @@ async def test_reauth_flow_error(
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": p_error} assert result["errors"] == p_error
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
), side_effect=get_train_stations,
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
@ -366,8 +414,8 @@ async def test_reauth_flow_error(
assert entry.data == { assert entry.data == {
"api_key": "1234567891", "api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C", "from": "Cst",
"to": "Uppsala C", "to": "U",
"time": "10:00", "time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
} }
@ -377,17 +425,20 @@ async def test_reauth_flow_error(
("side_effect", "p_error"), ("side_effect", "p_error"),
[ [
( (
NoTrainAnnouncementFound, NoTrainStationFound,
"no_trains", {"from": "invalid_station"},
), ),
( (
UnknownError, UnknownError,
"cannot_connect", {"base": "cannot_connect"},
), ),
], ],
) )
async def test_reauth_flow_error_departures( async def test_reauth_flow_error_departures(
hass: HomeAssistant, side_effect: Exception, p_error: str hass: HomeAssistant,
side_effect: Exception,
p_error: dict[str, str],
get_train_stations: list[StationInfoModel],
) -> None: ) -> None:
"""Test a reauthentication flow with error.""" """Test a reauthentication flow with error."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -395,13 +446,13 @@ async def test_reauth_flow_error_departures(
data={ data={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00", CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C", CONF_FROM: "Cst",
CONF_TO: "Uppsala C", CONF_TO: "U",
CONF_TIME: "10:00", CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS, CONF_WEEKDAY: WEEKDAYS,
}, },
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -409,10 +460,7 @@ async def test_reauth_flow_error_departures(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
side_effect=side_effect(), side_effect=side_effect(),
), ),
): ):
@ -424,11 +472,12 @@ async def test_reauth_flow_error_departures(
assert result["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert result["type"] is FlowResultType.FORM assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": p_error} assert result["errors"] == p_error
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
), ),
patch( patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
@ -449,8 +498,8 @@ async def test_reauth_flow_error_departures(
assert entry.data == { assert entry.data == {
"api_key": "1234567891", "api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C", "from": "Cst",
"to": "Uppsala C", "to": "U",
"time": "10:00", "time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
} }
@ -460,6 +509,7 @@ async def test_options_flow(
hass: HomeAssistant, hass: HomeAssistant,
get_trains: list[TrainStopModel], get_trains: list[TrainStopModel],
get_train_stop: TrainStopModel, get_train_stop: TrainStopModel,
get_train_stations: list[StationInfoModel],
) -> None: ) -> None:
"""Test a reauthentication flow.""" """Test a reauthentication flow."""
entry = MockConfigEntry( entry = MockConfigEntry(
@ -467,24 +517,28 @@ async def test_options_flow(
data={ data={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00", CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C", CONF_FROM: "Cst",
CONF_TO: "Uppsala C", CONF_TO: "U",
CONF_TIME: "10:00", CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS, CONF_WEEKDAY: WEEKDAYS,
}, },
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
return_value=get_trains, return_value=get_trains,
), ),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_stop", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_stop",
return_value=get_train_stop, return_value=get_train_stop,

View File

@ -4,8 +4,14 @@ from __future__ import annotations
from unittest.mock import patch from unittest.mock import patch
from pytrafikverket.exceptions import InvalidAuthentication, NoTrainStationFound import pytest
from pytrafikverket.models import TrainStopModel from pytrafikverket import (
InvalidAuthentication,
NoTrainStationFound,
StationInfoModel,
TrainStopModel,
UnknownError,
)
from syrupy.assertion import SnapshotAssertion from syrupy.assertion import SnapshotAssertion
from homeassistant.components.trafikverket_train.const import DOMAIN from homeassistant.components.trafikverket_train.const import DOMAIN
@ -28,14 +34,14 @@ async def test_unload_entry(
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
entry_id="1", entry_id="1",
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -65,13 +71,13 @@ async def test_auth_failed(
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
entry_id="1", entry_id="1",
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
with patch( with patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
side_effect=InvalidAuthentication, side_effect=InvalidAuthentication,
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -96,13 +102,13 @@ async def test_no_stations(
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
entry_id="1", entry_id="1",
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
with patch( with patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
side_effect=NoTrainStationFound, side_effect=NoTrainStationFound,
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)
@ -124,8 +130,8 @@ async def test_migrate_entity_unique_id(
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
entry_id="1", entry_id="1",
version=1, version=2,
minor_version=2, minor_version=1,
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
@ -139,7 +145,7 @@ async def test_migrate_entity_unique_id(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -158,8 +164,9 @@ async def test_migrate_entity_unique_id(
async def test_migrate_entry( async def test_migrate_entry(
hass: HomeAssistant, hass: HomeAssistant,
get_trains: list[TrainStopModel], get_trains: list[TrainStopModel],
get_train_stations: list[StationInfoModel],
) -> None: ) -> None:
"""Test migrate entry unique id.""" """Test migrate entry."""
entry = MockConfigEntry( entry = MockConfigEntry(
domain=DOMAIN, domain=DOMAIN,
source=SOURCE_USER, source=SOURCE_USER,
@ -174,7 +181,11 @@ async def test_migrate_entry(
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
), ),
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", "homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -186,8 +197,18 @@ async def test_migrate_entry(
assert entry.state is ConfigEntryState.LOADED assert entry.state is ConfigEntryState.LOADED
assert entry.version == 1 assert entry.version == 2
assert entry.minor_version == 2 assert entry.minor_version == 1
# Migration to version 2.1 changed from/to to use station signatures
assert entry.data == {
"api_key": "1234567890",
"from": "Cst",
"to": "U",
"time": None,
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
"name": "Stockholm C to Uppsala C",
}
# Migration to version 1.2 removed unique_id
assert entry.unique_id is None assert entry.unique_id is None
@ -201,18 +222,73 @@ async def test_migrate_entry_from_future_version_fails(
source=SOURCE_USER, source=SOURCE_USER,
data=ENTRY_CONFIG, data=ENTRY_CONFIG,
options=OPTIONS_CONFIG, options=OPTIONS_CONFIG,
version=2, version=3,
minor_version=1,
entry_id="1",
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.MIGRATION_ERROR
@pytest.mark.parametrize(
("side_effect"),
[
(InvalidAuthentication),
(NoTrainStationFound),
(UnknownError),
(Exception),
],
)
async def test_migrate_entry_fails(hass: HomeAssistant, side_effect: Exception) -> None:
"""Test migrate entry fails."""
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
version=1,
minor_version=1,
entry_id="1", entry_id="1",
) )
entry.add_to_hass(hass) entry.add_to_hass(hass)
with ( with (
patch( patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
), side_effect=side_effect(),
patch( ),
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops", ):
return_value=get_trains, await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.MIGRATION_ERROR
async def test_migrate_entry_fails_multiple_stations(
hass: HomeAssistant,
get_multiple_train_stations: list[StationInfoModel],
) -> None:
"""Test migrate entry fails on multiple stations found."""
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
version=1,
minor_version=1,
entry_id="1",
unique_id="321",
)
entry.add_to_hass(hass)
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
), ),
): ):
await hass.config_entries.async_setup(entry.entry_id) await hass.config_entries.async_setup(entry.entry_id)