Refactor Trafikverket Train to improve config flow (#97929)

* Refactor tvt

* review fixes

* review comments
This commit is contained in:
G Johansson 2023-08-07 17:25:02 +02:00 committed by GitHub
parent 15eed166ec
commit c4da5374ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 290 additions and 115 deletions

View File

@ -1341,6 +1341,7 @@ omit =
homeassistant/components/trafikverket_train/__init__.py homeassistant/components/trafikverket_train/__init__.py
homeassistant/components/trafikverket_train/coordinator.py homeassistant/components/trafikverket_train/coordinator.py
homeassistant/components/trafikverket_train/sensor.py homeassistant/components/trafikverket_train/sensor.py
homeassistant/components/trafikverket_train/util.py
homeassistant/components/trafikverket_weatherstation/__init__.py homeassistant/components/trafikverket_weatherstation/__init__.py
homeassistant/components/trafikverket_weatherstation/coordinator.py homeassistant/components/trafikverket_weatherstation/coordinator.py
homeassistant/components/trafikverket_weatherstation/sensor.py homeassistant/components/trafikverket_weatherstation/sensor.py

View File

@ -2,18 +2,24 @@
from __future__ import annotations from __future__ import annotations
from collections.abc import Mapping from collections.abc import Mapping
from datetime import datetime
import logging
from typing import Any from typing import Any
from pytrafikverket import TrafikverketTrain from pytrafikverket import TrafikverketTrain
from pytrafikverket.exceptions import ( from pytrafikverket.exceptions import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainAnnouncementFound,
MultipleTrainStationsFound, MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound, NoTrainStationFound,
UnknownError,
) )
import voluptuous as vol import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -22,18 +28,21 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig, SelectSelectorConfig,
SelectSelectorMode, SelectSelectorMode,
TextSelector, TextSelector,
TimeSelector,
) )
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import create_unique_id from .util import create_unique_id, next_departuredate
_LOGGER = logging.getLogger(__name__)
DATA_SCHEMA = vol.Schema( DATA_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_API_KEY): TextSelector(), vol.Required(CONF_API_KEY): TextSelector(),
vol.Required(CONF_FROM): TextSelector(), vol.Required(CONF_FROM): TextSelector(),
vol.Required(CONF_TO): TextSelector(), vol.Required(CONF_TO): TextSelector(),
vol.Optional(CONF_TIME): TextSelector(), vol.Optional(CONF_TIME): TimeSelector(),
vol.Required(CONF_WEEKDAY, default=WEEKDAYS): SelectSelector( vol.Required(CONF_WEEKDAY, default=WEEKDAYS): SelectSelector(
SelectSelectorConfig( SelectSelectorConfig(
options=WEEKDAYS, options=WEEKDAYS,
@ -51,6 +60,56 @@ DATA_SCHEMA_REAUTH = vol.Schema(
) )
async def validate_input(
hass: HomeAssistant,
api_key: str,
train_from: str,
train_to: str,
train_time: str | None,
weekdays: list[str],
) -> dict[str, str]:
"""Validate input from user input."""
errors: dict[str, str] = {}
when = dt_util.now()
if train_time:
departure_day = next_departuredate(weekdays)
if _time := dt_util.parse_time(train_time):
when = datetime.combine(
departure_day,
_time,
dt_util.get_time_zone(hass.config.time_zone),
)
try:
web_session = async_get_clientsession(hass)
train_api = TrafikverketTrain(web_session, api_key)
from_station = await train_api.async_get_train_station(train_from)
to_station = await train_api.async_get_train_station(train_to)
if train_time:
await train_api.async_get_train_stop(from_station, to_station, when)
else:
await train_api.async_get_next_train_stop(from_station, to_station, when)
except InvalidAuthentication:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
errors["base"] = "invalid_station"
except MultipleTrainStationsFound:
errors["base"] = "more_stations"
except NoTrainAnnouncementFound:
errors["base"] = "no_trains"
except MultipleTrainAnnouncementFound:
errors["base"] = "multiple_trains"
except UnknownError as error:
_LOGGER.error("Unknown error occurred during validation %s", str(error))
errors["base"] = "cannot_connect"
except Exception as error: # pylint: disable=broad-exception-caught
_LOGGER.error("Unknown exception occurred during validation %s", str(error))
errors["base"] = "cannot_connect"
return errors
class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Trafikverket Train integration.""" """Handle a config flow for Trafikverket Train integration."""
@ -58,15 +117,6 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
entry: config_entries.ConfigEntry | None entry: config_entries.ConfigEntry | None
async def validate_input(
self, api_key: str, train_from: str, train_to: str
) -> None:
"""Validate input from user input."""
web_session = async_get_clientsession(self.hass)
train_api = TrafikverketTrain(web_session, api_key)
await train_api.async_get_train_station(train_from)
await train_api.async_get_train_station(train_to)
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle re-authentication with Trafikverket.""" """Handle re-authentication with Trafikverket."""
@ -83,19 +133,15 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
api_key = user_input[CONF_API_KEY] api_key = user_input[CONF_API_KEY]
assert self.entry is not None assert self.entry is not None
try: errors = await validate_input(
await self.validate_input( self.hass,
api_key, self.entry.data[CONF_FROM], self.entry.data[CONF_TO] api_key,
self.entry.data[CONF_FROM],
self.entry.data[CONF_TO],
self.entry.data.get(CONF_TIME),
self.entry.data[CONF_WEEKDAY],
) )
except InvalidAuthentication: if not errors:
errors["base"] = "invalid_auth"
except NoTrainStationFound:
errors["base"] = "invalid_station"
except MultipleTrainStationsFound:
errors["base"] = "more_stations"
except Exception: # pylint: disable=broad-exception-caught
errors["base"] = "cannot_connect"
else:
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
self.entry, self.entry,
data={ data={
@ -129,20 +175,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
if train_time: if train_time:
name = f"{train_from} to {train_to} at {train_time}" name = f"{train_from} to {train_to} at {train_time}"
try: errors = await validate_input(
await self.validate_input(api_key, train_from, train_to) self.hass,
except InvalidAuthentication: api_key,
errors["base"] = "invalid_auth" train_from,
except NoTrainStationFound: train_to,
errors["base"] = "invalid_station" train_time,
except MultipleTrainStationsFound: train_days,
errors["base"] = "more_stations" )
except Exception: # pylint: disable=broad-exception-caught
errors["base"] = "cannot_connect"
else:
if train_time:
if bool(dt_util.parse_time(train_time) is None):
errors["base"] = "invalid_time"
if not errors: if not errors:
unique_id = create_unique_id( unique_id = create_unique_id(
train_from, train_to, train_time, train_days train_from, train_to, train_time, train_days
@ -163,6 +203,8 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
return self.async_show_form( return self.async_show_form(
step_id="user", step_id="user",
data_schema=DATA_SCHEMA, data_schema=self.add_suggested_values_to_schema(
DATA_SCHEMA, user_input or {}
),
errors=errors, errors=errors,
) )

View File

@ -2,7 +2,7 @@
from __future__ import annotations from __future__ import annotations
from dataclasses import dataclass from dataclasses import dataclass
from datetime import date, datetime, time, timedelta from datetime import datetime, time, timedelta
import logging import logging
from pytrafikverket import TrafikverketTrain from pytrafikverket import TrafikverketTrain
@ -15,7 +15,7 @@ from pytrafikverket.exceptions import (
from pytrafikverket.trafikverket_train import StationInfo, TrainStop from pytrafikverket.trafikverket_train import StationInfo, TrainStop
from homeassistant.config_entries import ConfigEntry from homeassistant.config_entries import ConfigEntry
from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY, WEEKDAYS from homeassistant.const import CONF_API_KEY, CONF_WEEKDAY
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant
from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.exceptions import ConfigEntryAuthFailed
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
@ -23,6 +23,7 @@ from homeassistant.helpers.update_coordinator import DataUpdateCoordinator, Upda
from homeassistant.util import dt as dt_util from homeassistant.util import dt as dt_util
from .const import CONF_TIME, DOMAIN from .const import CONF_TIME, DOMAIN
from .util import next_departuredate
@dataclass @dataclass
@ -44,27 +45,6 @@ _LOGGER = logging.getLogger(__name__)
TIME_BETWEEN_UPDATES = timedelta(minutes=5) TIME_BETWEEN_UPDATES = timedelta(minutes=5)
def _next_weekday(fromdate: date, weekday: int) -> date:
"""Return the date of the next time a specific weekday happen."""
days_ahead = weekday - fromdate.weekday()
if days_ahead <= 0:
days_ahead += 7
return fromdate + timedelta(days_ahead)
def _next_departuredate(departure: list[str]) -> date:
"""Calculate the next departuredate from an array input of short days."""
today_date = date.today()
today_weekday = date.weekday(today_date)
if WEEKDAYS[today_weekday] in departure:
return today_date
for day in departure:
next_departure = WEEKDAYS.index(day)
if next_departure > today_weekday:
return _next_weekday(today_date, next_departure)
return _next_weekday(today_date, WEEKDAYS.index(departure[0]))
def _get_as_utc(date_value: datetime | None) -> datetime | None: def _get_as_utc(date_value: datetime | None) -> datetime | None:
"""Return utc datetime or None.""" """Return utc datetime or None."""
if date_value: if date_value:
@ -110,7 +90,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
when = dt_util.now() when = dt_util.now()
state: TrainStop | None = None state: TrainStop | None = None
if self._time: if self._time:
departure_day = _next_departuredate(self._weekdays) departure_day = next_departuredate(self._weekdays)
when = datetime.combine( when = datetime.combine(
departure_day, departure_day,
self._time, self._time,

View File

@ -9,7 +9,8 @@
"invalid_auth": "[%key:common::config_flow::error::invalid_auth%]", "invalid_auth": "[%key:common::config_flow::error::invalid_auth%]",
"invalid_station": "Could not find a station with the specified name", "invalid_station": "Could not find a station with the specified name",
"more_stations": "Found multiple stations with the specified name", "more_stations": "Found multiple stations with the specified name",
"invalid_time": "Invalid time provided", "no_trains": "No train found",
"multiple_trains": "Multiple trains found",
"incorrect_api_key": "Invalid API key for selected account" "incorrect_api_key": "Invalid API key for selected account"
}, },
"step": { "step": {
@ -20,6 +21,9 @@
"from": "From station", "from": "From station",
"time": "Time (optional)", "time": "Time (optional)",
"weekday": "Days" "weekday": "Days"
},
"data_description": {
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure"
} }
}, },
"reauth_confirm": { "reauth_confirm": {

View File

@ -1,7 +1,9 @@
"""Utils for trafikverket_train.""" """Utils for trafikverket_train."""
from __future__ import annotations from __future__ import annotations
from datetime import time from datetime import date, time, timedelta
from homeassistant.const import WEEKDAYS
def create_unique_id( def create_unique_id(
@ -13,3 +15,24 @@ def create_unique_id(
f"{from_station.casefold().replace(' ', '')}-{to_station.casefold().replace(' ', '')}" f"{from_station.casefold().replace(' ', '')}-{to_station.casefold().replace(' ', '')}"
f"-{timestr.casefold().replace(' ', '')}-{str(weekdays)}" f"-{timestr.casefold().replace(' ', '')}-{str(weekdays)}"
) )
def next_weekday(fromdate: date, weekday: int) -> date:
"""Return the date of the next time a specific weekday happen."""
days_ahead = weekday - fromdate.weekday()
if days_ahead <= 0:
days_ahead += 7
return fromdate + timedelta(days_ahead)
def next_departuredate(departure: list[str]) -> date:
"""Calculate the next departuredate from an array input of short days."""
today_date = date.today()
today_weekday = date.weekday(today_date)
if WEEKDAYS[today_weekday] in departure:
return today_date
for day in departure:
next_departure = WEEKDAYS.index(day)
if next_departure > today_weekday:
return next_weekday(today_date, next_departure)
return next_weekday(today_date, WEEKDAYS.index(departure[0]))

View File

@ -6,8 +6,11 @@ from unittest.mock import patch
import pytest import pytest
from pytrafikverket.exceptions import ( from pytrafikverket.exceptions import (
InvalidAuthentication, InvalidAuthentication,
MultipleTrainAnnouncementFound,
MultipleTrainStationsFound, MultipleTrainStationsFound,
NoTrainAnnouncementFound,
NoTrainStationFound, NoTrainStationFound,
UnknownError,
) )
from homeassistant import config_entries from homeassistant import config_entries
@ -35,11 +38,13 @@ async def test_form(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), patch( ), patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True, return_value=True,
) as mock_setup_entry: ) as mock_setup_entry:
result2 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{ {
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
@ -51,9 +56,9 @@ async def test_form(hass: HomeAssistant) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] == FlowResultType.CREATE_ENTRY assert result["type"] == FlowResultType.CREATE_ENTRY
assert result2["title"] == "Stockholm C to Uppsala C at 10:00" assert result["title"] == "Stockholm C to Uppsala C at 10:00"
assert result2["data"] == { assert result["data"] == {
"api_key": "1234567890", "api_key": "1234567890",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C", "from": "Stockholm C",
@ -62,7 +67,7 @@ async def test_form(hass: HomeAssistant) -> None:
"weekday": ["mon", "fri"], "weekday": ["mon", "fri"],
} }
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
assert result2["result"].unique_id == "{}-{}-{}-{}".format( assert result["result"].unique_id == "{}-{}-{}-{}".format(
"stockholmc", "uppsalac", "10:00", "['mon', 'fri']" "stockholmc", "uppsalac", "10:00", "['mon', 'fri']"
) )
@ -92,11 +97,13 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), patch( ), patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True, return_value=True,
): ):
result2 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{ {
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
@ -108,8 +115,8 @@ async def test_form_entry_already_exist(hass: HomeAssistant) -> None:
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
assert result2["reason"] == "already_configured" assert result["reason"] == "already_configured"
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -137,19 +144,21 @@ async def test_flow_fails(
hass: HomeAssistant, side_effect: Exception, base_error: str hass: HomeAssistant, side_effect: Exception, base_error: str
) -> None: ) -> None:
"""Test config flow errors.""" """Test config flow errors."""
result4 = await hass.config_entries.flow.async_init( result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
assert result4["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
assert result4["step_id"] == config_entries.SOURCE_USER assert result["step_id"] == config_entries.SOURCE_USER
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
side_effect=side_effect(), side_effect=side_effect(),
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
): ):
result4 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result4["flow_id"], result["flow_id"],
user_input={ user_input={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_FROM: "Stockholm C", CONF_FROM: "Stockholm C",
@ -157,32 +166,55 @@ async def test_flow_fails(
}, },
) )
assert result4["errors"] == {"base": base_error} assert result["errors"] == {"base": base_error}
async def test_flow_fails_incorrect_time(hass: HomeAssistant) -> None: @pytest.mark.parametrize(
"""Test config flow errors due to bad time.""" ("side_effect", "base_error"),
result5 = await hass.config_entries.flow.async_init( [
(
NoTrainAnnouncementFound,
"no_trains",
),
(
MultipleTrainAnnouncementFound,
"multiple_trains",
),
(
UnknownError,
"cannot_connect",
),
],
)
async def test_flow_fails_departures(
hass: HomeAssistant, side_effect: Exception, base_error: str
) -> None:
"""Test config flow errors."""
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER} DOMAIN, context={"source": config_entries.SOURCE_USER}
) )
assert result5["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
assert result5["step_id"] == config_entries.SOURCE_USER assert result["step_id"] == config_entries.SOURCE_USER
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_next_train_stop",
side_effect=side_effect(),
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
): ):
result6 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result5["flow_id"], result["flow_id"],
user_input={ user_input={
CONF_API_KEY: "1234567890", CONF_API_KEY: "1234567890",
CONF_FROM: "Stockholm C", CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C", CONF_TO: "Uppsala C",
CONF_TIME: "25:25",
}, },
) )
assert result6["errors"] == {"base": "invalid_time"} assert result["errors"] == {"base": base_error}
async def test_reauth_flow(hass: HomeAssistant) -> None: async def test_reauth_flow(hass: HomeAssistant) -> None:
@ -216,18 +248,20 @@ async def test_reauth_flow(hass: HomeAssistant) -> None:
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), patch( ), patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True, return_value=True,
): ):
result2 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{CONF_API_KEY: "1234567891"}, {CONF_API_KEY: "1234567891"},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
assert result2["reason"] == "reauth_successful" assert result["reason"] == "reauth_successful"
assert entry.data == { assert entry.data == {
"api_key": "1234567891", "api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",
@ -290,31 +324,122 @@ async def test_reauth_flow_error(
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
side_effect=side_effect(), side_effect=side_effect(),
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
): ):
result2 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{CONF_API_KEY: "1234567890"}, {CONF_API_KEY: "1234567890"},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["step_id"] == "reauth_confirm" assert result["step_id"] == "reauth_confirm"
assert result2["type"] == FlowResultType.FORM assert result["type"] == FlowResultType.FORM
assert result2["errors"] == {"base": p_error} assert result["errors"] == {"base": p_error}
with patch( with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station", "homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), patch( ), patch(
"homeassistant.components.trafikverket_train.async_setup_entry", "homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True, return_value=True,
): ):
result2 = await hass.config_entries.flow.async_configure( result = await hass.config_entries.flow.async_configure(
result["flow_id"], result["flow_id"],
{CONF_API_KEY: "1234567891"}, {CONF_API_KEY: "1234567891"},
) )
await hass.async_block_till_done() await hass.async_block_till_done()
assert result2["type"] == FlowResultType.ABORT assert result["type"] == FlowResultType.ABORT
assert result2["reason"] == "reauth_successful" assert result["reason"] == "reauth_successful"
assert entry.data == {
"api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00",
"from": "Stockholm C",
"to": "Uppsala C",
"time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
}
@pytest.mark.parametrize(
("side_effect", "p_error"),
[
(
NoTrainAnnouncementFound,
"no_trains",
),
(
MultipleTrainAnnouncementFound,
"multiple_trains",
),
(
UnknownError,
"cannot_connect",
),
],
)
async def test_reauth_flow_error_departures(
hass: HomeAssistant, side_effect: Exception, p_error: str
) -> None:
"""Test a reauthentication flow with error."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}",
)
entry.add_to_hass(hass)
result = await hass.config_entries.flow.async_init(
DOMAIN,
context={
"source": config_entries.SOURCE_REAUTH,
"unique_id": entry.unique_id,
"entry_id": entry.entry_id,
},
data=entry.data,
)
with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
side_effect=side_effect(),
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{CONF_API_KEY: "1234567890"},
)
await hass.async_block_till_done()
assert result["step_id"] == "reauth_confirm"
assert result["type"] == FlowResultType.FORM
assert result["errors"] == {"base": p_error}
with patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_station",
), patch(
"homeassistant.components.trafikverket_train.config_flow.TrafikverketTrain.async_get_train_stop",
), patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True,
):
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{CONF_API_KEY: "1234567891"},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.ABORT
assert result["reason"] == "reauth_successful"
assert entry.data == { assert entry.data == {
"api_key": "1234567891", "api_key": "1234567891",
"name": "Stockholm C to Uppsala C at 10:00", "name": "Stockholm C to Uppsala C at 10:00",