Change Trafikverket Train to use station signatures (#131416)

Co-authored-by: Robert Resch <robert@resch.dev>
This commit is contained in:
G Johansson 2025-01-13 15:38:02 +01:00 committed by GitHub
parent 157548609b
commit 4709a3162c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 516 additions and 220 deletions

View File

@ -4,11 +4,21 @@ from __future__ import annotations
import logging
from homeassistant.config_entries import ConfigEntry
from homeassistant.core import HomeAssistant
from homeassistant.helpers import entity_registry as er
from pytrafikverket import (
InvalidAuthentication,
NoTrainStationFound,
TrafikverketTrain,
UnknownError,
)
from .const import PLATFORMS
from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY
from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_FROM, CONF_TO, PLATFORMS
from .coordinator import TVDataUpdateCoordinator
TVTrainConfigEntry = ConfigEntry[TVDataUpdateCoordinator]
@ -52,13 +62,55 @@ async def async_migrate_entry(hass: HomeAssistant, entry: TVTrainConfigEntry) ->
"""Migrate config entry."""
_LOGGER.debug("Migrating from version %s", entry.version)
if entry.version > 1:
if entry.version > 2:
# This means the user has downgraded from a future version
return False
if entry.version == 1 and entry.minor_version == 1:
# Remove unique id
hass.config_entries.async_update_entry(entry, unique_id=None, minor_version=2)
if entry.version == 1:
if entry.minor_version == 1:
# Remove unique id
hass.config_entries.async_update_entry(
entry, unique_id=None, minor_version=2
)
# Change from station names to station signatures
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, entry.data[CONF_API_KEY])
from_stations = await train_api.async_search_train_stations(
entry.data[CONF_FROM]
)
to_stations = await train_api.async_search_train_stations(
entry.data[CONF_TO]
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error
except NoTrainStationFound as error:
_LOGGER.error(
"Migration failed as no train station found with provided name %s",
str(error),
)
return False
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
return False
except Exception as error: # noqa: BLE001
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
return False
if len(from_stations) > 1 or len(to_stations) > 1:
_LOGGER.error(
"Migration failed as more than one station found with provided name"
)
return False
new_data = entry.data.copy()
new_data[CONF_FROM] = from_stations[0].signature
new_data[CONF_TO] = to_stations[0].signature
hass.config_entries.async_update_entry(
entry, data=new_data, version=2, minor_version=1
)
_LOGGER.debug(
"Migration to version %s.%s successful",

View File

@ -3,16 +3,14 @@
from __future__ import annotations
from collections.abc import Mapping
from datetime import datetime
import logging
from typing import Any
from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import (
from pytrafikverket import (
InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound,
StationInfoModel,
TrafikverketTrain,
UnknownError,
)
import voluptuous as vol
@ -28,16 +26,15 @@ from homeassistant.core import HomeAssistant, callback
from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv
from homeassistant.helpers.selector import (
SelectOptionDict,
SelectSelector,
SelectSelectorConfig,
SelectSelectorMode,
TextSelector,
TimeSelector,
)
import homeassistant.util.dt as dt_util
from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import next_departuredate
_LOGGER = logging.getLogger(__name__)
@ -68,49 +65,23 @@ DATA_SCHEMA_REAUTH = vol.Schema(
)
async def validate_input(
async def validate_station(
hass: HomeAssistant,
api_key: str,
train_from: str,
train_to: str,
train_time: str | None,
weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]:
train_station: str,
field: str,
) -> tuple[list[StationInfoModel], dict[str, str]]:
"""Validate input from user input."""
errors: dict[str, str] = {}
when = dt_util.now()
if train_time:
departure_day = next_departuredate(weekdays)
if _time := dt_util.parse_time(train_time):
when = datetime.combine(
departure_day,
_time,
dt_util.get_default_time_zone(),
)
stations = []
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, api_key)
from_station = await train_api.async_search_train_station(train_from)
to_station = await train_api.async_search_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else:
await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
stations = await train_api.async_search_train_stations(train_station)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
errors["base"] = "invalid_station"
except MultipleTrainStationsFound:
errors["base"] = "more_stations"
except NoTrainAnnouncementFound:
errors["base"] = "no_trains"
errors[field] = "invalid_station"
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
errors["base"] = "cannot_connect"
@ -118,14 +89,18 @@ async def validate_input(
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
errors["base"] = "cannot_connect"
return errors
return (stations, errors)
class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Trafikverket Train integration."""
VERSION = 1
MINOR_VERSION = 2
VERSION = 2
MINOR_VERSION = 1
_from_stations: list[StationInfoModel]
_to_stations: list[StationInfoModel]
_data: dict[str, Any]
@staticmethod
@callback
@ -151,14 +126,11 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
api_key = user_input[CONF_API_KEY]
reauth_entry = self._get_reauth_entry()
errors = await validate_input(
_, errors = await validate_station(
self.hass,
api_key,
reauth_entry.data[CONF_FROM],
reauth_entry.data[CONF_TO],
reauth_entry.data.get(CONF_TIME),
reauth_entry.data[CONF_WEEKDAY],
reauth_entry.options.get(CONF_FILTER_PRODUCT),
CONF_FROM,
)
if not errors:
return self.async_update_reload_and_abort(
@ -193,38 +165,40 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
if train_time:
name = f"{train_from} to {train_to} at {train_time}"
errors = await validate_input(
self.hass,
api_key,
train_from,
train_to,
train_time,
train_days,
filter_product,
self._from_stations, from_errors = await validate_station(
self.hass, api_key, train_from, CONF_FROM
)
self._to_stations, to_errors = await validate_station(
self.hass, api_key, train_to, CONF_TO
)
errors = {**from_errors, **to_errors}
if not errors:
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
if len(self._from_stations) == 1 and len(self._to_stations) == 1:
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: self._from_stations[0].signature,
CONF_TO: self._to_stations[0].signature,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: self._from_stations[0].signature,
CONF_TO: self._to_stations[0].signature,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
self._data = user_input
return await self.async_step_select_stations()
return self.async_show_form(
step_id="user",
@ -234,6 +208,77 @@ class TVTrainConfigFlow(ConfigFlow, domain=DOMAIN):
errors=errors,
)
async def async_step_select_stations(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the select station step."""
if user_input is not None:
api_key: str = self._data[CONF_API_KEY]
train_from: str = user_input[CONF_FROM]
train_to: str = user_input[CONF_TO]
train_time: str | None = self._data.get(CONF_TIME)
train_days: list = self._data[CONF_WEEKDAY]
filter_product: str | None = self._data[CONF_FILTER_PRODUCT]
if filter_product == "":
filter_product = None
name = f"{self._data[CONF_FROM]} to {self._data[CONF_TO]}"
if train_time:
name = (
f"{self._data[CONF_FROM]} to {self._data[CONF_TO]} at {train_time}"
)
self._async_abort_entries_match(
{
CONF_API_KEY: api_key,
CONF_FROM: train_from,
CONF_TO: user_input[CONF_TO],
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
CONF_FILTER_PRODUCT: filter_product,
}
)
return self.async_create_entry(
title=name,
data={
CONF_API_KEY: api_key,
CONF_NAME: name,
CONF_FROM: train_from,
CONF_TO: train_to,
CONF_TIME: train_time,
CONF_WEEKDAY: train_days,
},
options={CONF_FILTER_PRODUCT: filter_product},
)
from_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._from_stations
]
to_options = [
SelectOptionDict(value=station.signature, label=station.station_name)
for station in self._to_stations
]
schema = {}
if len(from_options) > 1:
schema[vol.Required(CONF_FROM)] = SelectSelector(
SelectSelectorConfig(
options=from_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)
if len(to_options) > 1:
schema[vol.Required(CONF_TO)] = SelectSelector(
SelectSelectorConfig(
options=to_options, mode=SelectSelectorMode.DROPDOWN, sort=True
)
)
return self.async_show_form(
step_id="select_stations",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(schema), user_input or {}
),
)
class TVTrainOptionsFlowHandler(OptionsFlow):
"""Handle Trafikverket Train options."""

View File

@ -7,15 +7,16 @@ from datetime import datetime, time, timedelta
import logging
from typing import TYPE_CHECKING
from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import (
from pytrafikverket import (
InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound,
StationInfoModel,
TrafikverketTrain,
TrainStopModel,
UnknownError,
)
from pytrafikverket.models import StationInfoModel, TrainStopModel
from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY
from homeassistant.core import HomeAssistant
@ -93,11 +94,15 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
async def _async_setup(self) -> None:
"""Initiate stations."""
try:
self.to_station = await self._train_api.async_search_train_station(
self.config_entry.data[CONF_TO]
self.to_station = (
await self._train_api.async_get_train_station_from_signature(
self.config_entry.data[CONF_TO]
)
)
self.from_station = await self._train_api.async_search_train_station(
self.config_entry.data[CONF_FROM]
self.from_station = (
await self._train_api.async_get_train_station_from_signature(
self.config_entry.data[CONF_FROM]
)
)
except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error

View File

@ -27,6 +27,13 @@
"filter_product": "To filter by product description add the phrase here to match"
}
},
"select_stations": {
"description": "More than one station was found with the provided name, select the correct ones from the provided lists",
"data": {
"to": "To station",
"from": "From station"
}
},
"reauth_confirm": {
"data": {
"api_key": "[%key:common::config_flow::data::api_key%]"

View File

@ -6,7 +6,7 @@ from datetime import datetime, timedelta
from unittest.mock import patch
import pytest
from pytrafikverket.models import TrainStopModel
from pytrafikverket import StationInfoModel, TrainStopModel
from homeassistant.components.trafikverket_train.const import DOMAIN
from homeassistant.config_entries import SOURCE_USER
@ -40,6 +40,9 @@ async def load_integration_from_entry(
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
):
await hass.config_entries.async_setup(config_entry_id)
await hass.async_block_till_done()
@ -50,8 +53,8 @@ async def load_integration_from_entry(
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
entry_id="1",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
config_entry.add_to_hass(hass)
await setup_config_entry_with_mocked_data(config_entry.entry_id)
@ -61,8 +64,8 @@ async def load_integration_from_entry(
source=SOURCE_USER,
data=ENTRY_CONFIG2,
entry_id="2",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
config_entry2.add_to_hass(hass)
await setup_config_entry_with_mocked_data(config_entry2.entry_id)
@ -171,3 +174,57 @@ def fixture_get_train_stop() -> TrainStopModel:
modified_time=datetime(2023, 5, 1, 11, 0, tzinfo=dt_util.UTC),
product_description=["Regionaltåg"],
)
@pytest.fixture(name="get_train_stations")
def fixture_get_train_station() -> list[list[StationInfoModel]]:
"""Construct StationInfoModel Mock."""
return [
[
StationInfoModel(
signature="Cst",
station_name="Stockholm C",
advertised=True,
)
],
[
StationInfoModel(
signature="U",
station_name="Uppsala C",
advertised=True,
)
],
]
@pytest.fixture(name="get_multiple_train_stations")
def fixture_get_multiple_train_station() -> list[list[StationInfoModel]]:
"""Construct StationInfoModel Mock."""
return [
[
StationInfoModel(
signature="Cst",
station_name="Stockholm C",
advertised=True,
),
StationInfoModel(
signature="Csu",
station_name="Stockholm City",
advertised=True,
),
],
[
StationInfoModel(
signature="U",
station_name="Uppsala C",
advertised=True,
),
StationInfoModel(
signature="Ups",
station_name="Uppsala City",
advertised=True,
),
],
]

View File

@ -5,14 +5,13 @@ from __future__ import annotations
from unittest.mock import patch
import pytest
from pytrafikverket.exceptions import (
from pytrafikverket import (
InvalidAuthentication,
MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound,
StationInfoModel,
TrainStopModel,
UnknownError,
)
from pytrafikverket.models import TrainStopModel
from homeassistant import config_entries
from homeassistant.components.trafikverket_train.const import (
@ -29,7 +28,9 @@ from homeassistant.data_entry_flow import FlowResultType
from tests.common import MockConfigEntry
async def test_form(hass: HomeAssistant) -> None:
async def test_form(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test we get the form."""
result = await hass.config_entries.flow.async_init(
@ -40,10 +41,8 @@ async def test_form(hass: HomeAssistant) -> None:
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
@ -67,8 +66,8 @@ async def test_form(hass: HomeAssistant) -> None:
assert result["data"] == {
"api_key": "1234567890",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C",
"to": "Uppsala C",
"from": "Cst",
"to": "U",
"time": "10:00",
"weekday": ["mon", "fri"],
}
@ -76,7 +75,70 @@ async def test_form(hass: HomeAssistant) -> None:
assert len(mock_setup_entry.mock_calls) == 1
async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
async def test_form_multiple_stations(
hass: HomeAssistant, get_multiple_train_stations: list[StationInfoModel]
) -> None:
"""Test we get the form."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {}
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_API_KEY: "1234567890",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_TIME: "10:00",
CONF_WEEKDAY: ["mon", "fri"],
},
)
await hass.async_block_till_done()
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True,
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_FROM: "Csu",
CONF_TO: "Ups",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "Stockholm C to Uppsala C at 10:00"
assert result["data"] == {
"api_key": "1234567890",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Csu",
"to": "Ups",
"time": "10:00",
"weekday": ["mon", "fri"],
}
assert result["options"] == {"filter_product": None}
async def test_form_entry_already_exist(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test flow aborts when entry already exist."""
entry = MockConfigEntry(
@ -84,14 +146,14 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_FROM: "Cst",
CONF_TO: "U",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
CONF_FILTER_PRODUCT: None,
},
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
@ -103,10 +165,11 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
@ -130,28 +193,24 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
@pytest.mark.parametrize(
("side_effect", "base_error"),
("side_effect", "p_error"),
[
(
InvalidAuthentication,
"invalid_auth",
{"base": "invalid_auth"},
),
(
NoTrainStationFound,
"invalid_station",
),
(
MultipleTrainStationsFound,
"more_stations",
{"from": "invalid_station", "to": "invalid_station"},
),
(
Exception,
"cannot_connect",
{"base": "cannot_connect"},
),
],
)
async def test_flow_fails(
hass: HomeAssistant, side_effect: Exception, base_error: str
hass: HomeAssistant, side_effect: Exception, p_error: dict[str, str]
) -> None:
"""Test config flow errors."""
result = await hass.config_entries.flow.async_init(
@ -163,12 +222,9 @@ async def test_flow_fails(
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(),
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
@ -179,24 +235,24 @@ async def test_flow_fails(
},
)
assert result["errors"] == {"base": base_error}
assert result["errors"] == p_error
@pytest.mark.parametrize(
("side_effect", "base_error"),
("side_effect", "p_error"),
[
(
NoTrainAnnouncementFound,
"no_trains",
NoTrainStationFound,
{"from": "invalid_station", "to": "invalid_station"},
),
(
UnknownError,
"cannot_connect",
{"base": "cannot_connect"},
),
],
)
async def test_flow_fails_departures(
hass: HomeAssistant, side_effect: Exception, base_error: str
hass: HomeAssistant, side_effect: Exception, p_error: dict[str, str]
) -> None:
"""Test config flow errors."""
result = await hass.config_entries.flow.async_init(
@ -208,15 +264,9 @@ async def test_flow_fails_departures(
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_next_train_stops",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(),
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
@ -227,23 +277,25 @@ async def test_flow_fails_departures(
},
)
assert result["errors"] == {"base": base_error}
assert result["errors"] == p_error
async def test_reauth_flow(hass: HomeAssistant) -> None:
async def test_reauth_flow(
hass: HomeAssistant, get_train_stations: list[StationInfoModel]
) -> None:
"""Test a reauthentication flow."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_FROM: "Cst",
CONF_TO: "U",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
@ -254,10 +306,8 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
@ -275,8 +325,8 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
assert entry.data == {
"api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C",
"to": "Uppsala C",
"from": "Cst",
"to": "U",
"time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
}
@ -287,24 +337,27 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
[
(
InvalidAuthentication,
"invalid_auth",
{"base": "invalid_auth"},
),
(
NoTrainStationFound,
"invalid_station",
{"from": "invalid_station"},
),
(
MultipleTrainStationsFound,
"more_stations",
UnknownError,
{"base": "cannot_connect"},
),
(
Exception,
"cannot_connect",
{"base": "cannot_connect"},
),
],
)
async def test_reauth_flow_error(
hass: HomeAssistant, side_effect: Exception, p_error: str
hass: HomeAssistant,
side_effect: Exception,
p_error: dict[str, str],
get_train_stations: list[StationInfoModel],
) -> None:
"""Test a reauthentication flow with error."""
entry = MockConfigEntry(
@ -312,13 +365,13 @@ async def test_reauth_flow_error(
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_FROM: "Cst",
CONF_TO: "U",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
@ -326,12 +379,9 @@ async def test_reauth_flow_error(
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(),
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
@ -341,14 +391,12 @@ async def test_reauth_flow_error(
assert result["step_id"] == "reauth_confirm"
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": p_error}
assert result["errors"] == p_error
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
@ -366,8 +414,8 @@ async def test_reauth_flow_error(
assert entry.data == {
"api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C",
"to": "Uppsala C",
"from": "Cst",
"to": "U",
"time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
}
@ -377,17 +425,20 @@ async def test_reauth_flow_error(
("side_effect", "p_error"),
[
(
NoTrainAnnouncementFound,
"no_trains",
NoTrainStationFound,
{"from": "invalid_station"},
),
(
UnknownError,
"cannot_connect",
{"base": "cannot_connect"},
),
],
)
async def test_reauth_flow_error_departures(
hass: HomeAssistant, side_effect: Exception, p_error: str
hass: HomeAssistant,
side_effect: Exception,
p_error: dict[str, str],
get_train_stations: list[StationInfoModel],
) -> None:
"""Test a reauthentication flow with error."""
entry = MockConfigEntry(
@ -395,13 +446,13 @@ async def test_reauth_flow_error_departures(
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_FROM: "Cst",
CONF_TO: "U",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
@ -409,10 +460,7 @@ async def test_reauth_flow_error_departures(
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(),
),
):
@ -424,11 +472,12 @@ async def test_reauth_flow_error_departures(
assert result["step_id"] == "reauth_confirm"
assert result["type"] is FlowResultType.FORM
assert result["errors"] == {"base": p_error}
assert result["errors"] == p_error
with (
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
@ -449,8 +498,8 @@ async def test_reauth_flow_error_departures(
assert entry.data == {
"api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C",
"to": "Uppsala C",
"from": "Cst",
"to": "U",
"time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
}
@ -460,6 +509,7 @@ async def test_options_flow(
hass: HomeAssistant,
get_trains: list[TrainStopModel],
get_train_stop: TrainStopModel,
get_train_stations: list[StationInfoModel],
) -> None:
"""Test a reauthentication flow."""
entry = MockConfigEntry(
@ -467,24 +517,28 @@ async def test_options_flow(
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_FROM: "Cst",
CONF_TO: "U",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
return_value=get_trains,
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_stop",
return_value=get_train_stop,

View File

@ -4,8 +4,14 @@ from __future__ import annotations
from unittest.mock import patch
from pytrafikverket.exceptions import InvalidAuthentication, NoTrainStationFound
from pytrafikverket.models import TrainStopModel
import pytest
from pytrafikverket import (
InvalidAuthentication,
NoTrainStationFound,
StationInfoModel,
TrainStopModel,
UnknownError,
)
from syrupy.assertion import SnapshotAssertion
from homeassistant.components.trafikverket_train.const import DOMAIN
@ -28,14 +34,14 @@ async def test_unload_entry(
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
entry_id="1",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -65,13 +71,13 @@ async def test_auth_failed(
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
entry_id="1",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
with patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
side_effect=InvalidAuthentication,
):
await hass.config_entries.async_setup(entry.entry_id)
@ -96,13 +102,13 @@ async def test_no_stations(
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
entry_id="1",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
with patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
side_effect=NoTrainStationFound,
):
await hass.config_entries.async_setup(entry.entry_id)
@ -124,8 +130,8 @@ async def test_migrate_entity_unique_id(
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
entry_id="1",
version=1,
minor_version=2,
version=2,
minor_version=1,
)
entry.add_to_hass(hass)
@ -139,7 +145,7 @@ async def test_migrate_entity_unique_id(
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -158,8 +164,9 @@ async def test_migrate_entity_unique_id(
async def test_migrate_entry(
hass: HomeAssistant,
get_trains: list[TrainStopModel],
get_train_stations: list[StationInfoModel],
) -> None:
"""Test migrate entry unique id."""
"""Test migrate entry."""
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
@ -174,7 +181,11 @@ async def test_migrate_entry(
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_train_station_from_signature",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_train_stations,
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
@ -186,8 +197,18 @@ async def test_migrate_entry(
assert entry.state is ConfigEntryState.LOADED
assert entry.version == 1
assert entry.minor_version == 2
assert entry.version == 2
assert entry.minor_version == 1
# Migration to version 2.1 changed from/to to use station signatures
assert entry.data == {
"api_key": "1234567890",
"from": "Cst",
"to": "U",
"time": None,
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
"name": "Stockholm C to Uppsala C",
}
# Migration to version 1.2 removed unique_id
assert entry.unique_id is None
@ -201,18 +222,73 @@ async def test_migrate_entry_from_future_version_fails(
source=SOURCE_USER,
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
version=2,
version=3,
minor_version=1,
entry_id="1",
)
entry.add_to_hass(hass)
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.MIGRATION_ERROR
@pytest.mark.parametrize(
("side_effect"),
[
(InvalidAuthentication),
(NoTrainStationFound),
(UnknownError),
(Exception),
],
)
async def test_migrate_entry_fails(hass: HomeAssistant, side_effect: Exception) -> None:
"""Test migrate entry fails."""
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
version=1,
minor_version=1,
entry_id="1",
)
entry.add_to_hass(hass)
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_station",
),
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_get_next_train_stops",
return_value=get_trains,
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_search_train_stations",
side_effect=side_effect(),
),
):
await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
assert entry.state is ConfigEntryState.MIGRATION_ERROR
async def test_migrate_entry_fails_multiple_stations(
hass: HomeAssistant,
get_multiple_train_stations: list[StationInfoModel],
) -> None:
"""Test migrate entry fails on multiple stations found."""
entry = MockConfigEntry(
domain=DOMAIN,
source=SOURCE_USER,
data=ENTRY_CONFIG,
options=OPTIONS_CONFIG,
version=1,
minor_version=1,
entry_id="1",
unique_id="321",
)
entry.add_to_hass(hass)
with (
patch(
"homeassistant.components.trafikverket_train.coordinator.TrafikverketTrain.async_search_train_stations",
side_effect=get_multiple_train_stations,
),
):
await hass.config_entries.async_setup(entry.entry_id)