Add product filtering feature to Trafikverket Train (#86343)

This commit is contained in:
G Johansson 2023-08-09 17:20:30 +02:00 committed by GitHub
parent 0317afeb17
commit 4c03077dfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 203 additions and 20 deletions

View File

@ -15,7 +15,7 @@ from homeassistant.exceptions import ConfigEntryAuthFailed, ConfigEntryNotReady
from homeassistant.helpers import entity_registry as er from homeassistant.helpers import entity_registry as er
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
from .const import CONF_FROM, CONF_TO, DOMAIN, PLATFORMS from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TO, DOMAIN, PLATFORMS
from .coordinator import TVDataUpdateCoordinator from .coordinator import TVDataUpdateCoordinator
@ -36,7 +36,9 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
f" {entry.data[CONF_TO]}. Error: {error} " f" {entry.data[CONF_TO]}. Error: {error} "
) from error ) from error
coordinator = TVDataUpdateCoordinator(hass, entry, to_station, from_station) coordinator = TVDataUpdateCoordinator(
hass, entry, to_station, from_station, entry.options.get(CONF_FILTER_PRODUCT)
)
await coordinator.async_config_entry_first_refresh() await coordinator.async_config_entry_first_refresh()
hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator hass.data.setdefault(DOMAIN, {})[entry.entry_id] = coordinator
@ -49,6 +51,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
) )
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(update_listener))
return True return True
@ -57,3 +60,8 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload Trafikverket Weatherstation config entry.""" """Unload Trafikverket Weatherstation config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS) return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)

View File

@ -19,7 +19,7 @@ import voluptuous as vol
from homeassistant import config_entries from homeassistant import config_entries
from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS from homeassistant.const import CONF_API_KEY, CONF_NAME, CONF_WEEKDAY, WEEKDAYS
from homeassistant.core import HomeAssistant from homeassistant.core import HomeAssistant, callback
from homeassistant.data_entry_flow import FlowResult from homeassistant.data_entry_flow import FlowResult
from homeassistant.helpers.aiohttp_client import async_get_clientsession from homeassistant.helpers.aiohttp_client import async_get_clientsession
import homeassistant.helpers.config_validation as cv import homeassistant.helpers.config_validation as cv
@ -32,11 +32,15 @@ from homeassistant.helpers.selector import (
) )
import homeassistant.util.dt as dt_util import homeassistant.util.dt as dt_util
from .const import CONF_FROM, CONF_TIME, CONF_TO, DOMAIN from .const import CONF_FILTER_PRODUCT, CONF_FROM, CONF_TIME, CONF_TO, DOMAIN
from .util import create_unique_id, next_departuredate from .util import create_unique_id, next_departuredate
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
OPTION_SCHEMA = {
vol.Optional(CONF_FILTER_PRODUCT, default=""): TextSelector(),
}
DATA_SCHEMA = vol.Schema( DATA_SCHEMA = vol.Schema(
{ {
vol.Required(CONF_API_KEY): TextSelector(), vol.Required(CONF_API_KEY): TextSelector(),
@ -52,7 +56,7 @@ DATA_SCHEMA = vol.Schema(
) )
), ),
} }
) ).extend(OPTION_SCHEMA)
DATA_SCHEMA_REAUTH = vol.Schema( DATA_SCHEMA_REAUTH = vol.Schema(
{ {
vol.Required(CONF_API_KEY): cv.string, vol.Required(CONF_API_KEY): cv.string,
@ -67,6 +71,7 @@ async def validate_input(
train_to: str, train_to: str,
train_time: str | None, train_time: str | None,
weekdays: list[str], weekdays: list[str],
product_filter: str | None,
) -> dict[str, str]: ) -> dict[str, str]:
"""Validate input from user input.""" """Validate input from user input."""
errors: dict[str, str] = {} errors: dict[str, str] = {}
@ -87,9 +92,13 @@ async def validate_input(
from_station = await train_api.async_get_train_station(train_from) from_station = await train_api.async_get_train_station(train_from)
to_station = await train_api.async_get_train_station(train_to) to_station = await train_api.async_get_train_station(train_to)
if train_time: if train_time:
await train_api.async_get_train_stop(from_station, to_station, when) await train_api.async_get_train_stop(
from_station, to_station, when, product_filter
)
else: else:
await train_api.async_get_next_train_stop(from_station, to_station, when) await train_api.async_get_next_train_stop(
from_station, to_station, when, product_filter
)
except InvalidAuthentication: except InvalidAuthentication:
errors["base"] = "invalid_auth" errors["base"] = "invalid_auth"
except NoTrainStationFound: except NoTrainStationFound:
@ -117,6 +126,14 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
entry: config_entries.ConfigEntry | None entry: config_entries.ConfigEntry | None
@staticmethod
@callback
def async_get_options_flow(
config_entry: config_entries.ConfigEntry,
) -> TVTrainOptionsFlowHandler:
"""Get the options flow for this handler."""
return TVTrainOptionsFlowHandler(config_entry)
async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult: async def async_step_reauth(self, entry_data: Mapping[str, Any]) -> FlowResult:
"""Handle re-authentication with Trafikverket.""" """Handle re-authentication with Trafikverket."""
@ -140,6 +157,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
self.entry.data[CONF_TO], self.entry.data[CONF_TO],
self.entry.data.get(CONF_TIME), self.entry.data.get(CONF_TIME),
self.entry.data[CONF_WEEKDAY], self.entry.data[CONF_WEEKDAY],
self.entry.options.get(CONF_FILTER_PRODUCT),
) )
if not errors: if not errors:
self.hass.config_entries.async_update_entry( self.hass.config_entries.async_update_entry(
@ -170,6 +188,10 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
train_to: str = user_input[CONF_TO] train_to: str = user_input[CONF_TO]
train_time: str | None = user_input.get(CONF_TIME) train_time: str | None = user_input.get(CONF_TIME)
train_days: list = user_input[CONF_WEEKDAY] train_days: list = user_input[CONF_WEEKDAY]
filter_product: str | None = user_input[CONF_FILTER_PRODUCT]
if filter_product == "":
filter_product = None
name = f"{train_from} to {train_to}" name = f"{train_from} to {train_to}"
if train_time: if train_time:
@ -182,6 +204,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
train_to, train_to,
train_time, train_time,
train_days, train_days,
filter_product,
) )
if not errors: if not errors:
unique_id = create_unique_id( unique_id = create_unique_id(
@ -199,6 +222,7 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
CONF_TIME: train_time, CONF_TIME: train_time,
CONF_WEEKDAY: train_days, CONF_WEEKDAY: train_days,
}, },
options={CONF_FILTER_PRODUCT: filter_product},
) )
return self.async_show_form( return self.async_show_form(
@ -208,3 +232,27 @@ class TVTrainConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
), ),
errors=errors, errors=errors,
) )
class TVTrainOptionsFlowHandler(config_entries.OptionsFlowWithConfigEntry):
"""Handle Trafikverket Train options."""
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> FlowResult:
"""Manage Trafikverket Train options."""
errors: dict[str, Any] = {}
if user_input:
if not (_filter := user_input.get(CONF_FILTER_PRODUCT)) or _filter == "":
user_input[CONF_FILTER_PRODUCT] = None
return self.async_create_entry(data=user_input)
return self.async_show_form(
step_id="init",
data_schema=self.add_suggested_values_to_schema(
vol.Schema(OPTION_SCHEMA),
user_input or self.options,
),
errors=errors,
)

View File

@ -8,3 +8,4 @@ ATTRIBUTION = "Data provided by Trafikverket"
CONF_FROM = "from" CONF_FROM = "from"
CONF_TO = "to" CONF_TO = "to"
CONF_TIME = "time" CONF_TIME = "time"
CONF_FILTER_PRODUCT = "filter_product"

View File

@ -39,6 +39,7 @@ class TrainData:
actual_time: datetime | None actual_time: datetime | None
other_info: str | None other_info: str | None
deviation: str | None deviation: str | None
product_filter: str | None
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@ -68,6 +69,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
entry: ConfigEntry, entry: ConfigEntry,
to_station: StationInfo, to_station: StationInfo,
from_station: StationInfo, from_station: StationInfo,
filter_product: str | None,
) -> None: ) -> None:
"""Initialize the Trafikverket coordinator.""" """Initialize the Trafikverket coordinator."""
super().__init__( super().__init__(
@ -83,6 +85,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
self.to_station: StationInfo = to_station self.to_station: StationInfo = to_station
self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME]) self._time: time | None = dt_util.parse_time(entry.data[CONF_TIME])
self._weekdays: list[str] = entry.data[CONF_WEEKDAY] self._weekdays: list[str] = entry.data[CONF_WEEKDAY]
self._filter_product = filter_product
async def _async_update_data(self) -> TrainData: async def _async_update_data(self) -> TrainData:
"""Fetch data from Trafikverket.""" """Fetch data from Trafikverket."""
@ -99,11 +102,11 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
try: try:
if self._time: if self._time:
state = await self._train_api.async_get_train_stop( state = await self._train_api.async_get_train_stop(
self.from_station, self.to_station, when self.from_station, self.to_station, when, self._filter_product
) )
else: else:
state = await self._train_api.async_get_next_train_stop( state = await self._train_api.async_get_next_train_stop(
self.from_station, self.to_station, when self.from_station, self.to_station, when, self._filter_product
) )
except InvalidAuthentication as error: except InvalidAuthentication as error:
raise ConfigEntryAuthFailed from error raise ConfigEntryAuthFailed from error
@ -134,6 +137,7 @@ class TVDataUpdateCoordinator(DataUpdateCoordinator[TrainData]):
actual_time=_get_as_utc(state.time_at_location), actual_time=_get_as_utc(state.time_at_location),
other_info=_get_as_joined(state.other_information), other_info=_get_as_joined(state.other_information),
deviation=_get_as_joined(state.deviations), deviation=_get_as_joined(state.deviations),
product_filter=self._filter_product,
) )
return states return states

View File

@ -1,9 +1,10 @@
"""Train information for departures and delays, provided by Trafikverket.""" """Train information for departures and delays, provided by Trafikverket."""
from __future__ import annotations from __future__ import annotations
from collections.abc import Callable from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from typing import Any
from homeassistant.components.sensor import ( from homeassistant.components.sensor import (
SensorDeviceClass, SensorDeviceClass,
@ -22,6 +23,8 @@ from homeassistant.helpers.update_coordinator import CoordinatorEntity
from .const import ATTRIBUTION, DOMAIN from .const import ATTRIBUTION, DOMAIN
from .coordinator import TrainData, TVDataUpdateCoordinator from .coordinator import TrainData, TVDataUpdateCoordinator
ATTR_PRODUCT_FILTER = "product_filter"
@dataclass @dataclass
class TrafikverketRequiredKeysMixin: class TrafikverketRequiredKeysMixin:
@ -158,3 +161,10 @@ class TrainSensor(CoordinatorEntity[TVDataUpdateCoordinator], SensorEntity):
def _handle_coordinator_update(self) -> None: def _handle_coordinator_update(self) -> None:
self._update_attr() self._update_attr()
return super()._handle_coordinator_update() return super()._handle_coordinator_update()
@property
def extra_state_attributes(self) -> Mapping[str, Any] | None:
"""Return additional attributes for Trafikverket Train sensor."""
if self.coordinator.data.product_filter:
return {ATTR_PRODUCT_FILTER: self.coordinator.data.product_filter}
return None

View File

@ -20,10 +20,12 @@
"to": "To station", "to": "To station",
"from": "From station", "from": "From station",
"time": "Time (optional)", "time": "Time (optional)",
"weekday": "Days" "weekday": "Days",
"filter_product": "Filter by product description"
}, },
"data_description": { "data_description": {
"time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure" "time": "Set time to search specifically at this time of day, must be exact time as scheduled train departure",
"filter_product": "To filter by product description add the phrase here to match"
} }
}, },
"reauth_confirm": { "reauth_confirm": {
@ -33,6 +35,18 @@
} }
} }
}, },
"options": {
"step": {
"init": {
"data": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data::filter_product%]"
},
"data_description": {
"filter_product": "[%key:component::trafikverket_train::config::step::user::data_description::filter_product%]"
}
}
}
},
"selector": { "selector": {
"weekday": { "weekday": {
"options": { "options": {
@ -49,7 +63,12 @@
"entity": { "entity": {
"sensor": { "sensor": {
"departure_time": { "departure_time": {
"name": "Departure time" "name": "Departure time",
"state_attributes": {
"product_filter": {
"name": "Train filtering"
}
}
}, },
"departure_state": { "departure_state": {
"name": "Departure state", "name": "Departure state",
@ -57,28 +76,68 @@
"on_time": "On time", "on_time": "On time",
"delayed": "Delayed", "delayed": "Delayed",
"canceled": "Cancelled" "canceled": "Cancelled"
},
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
} }
}, },
"cancelled": { "cancelled": {
"name": "Cancelled" "name": "Cancelled",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"delayed_time": { "delayed_time": {
"name": "Delayed time" "name": "Delayed time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"planned_time": { "planned_time": {
"name": "Planned time" "name": "Planned time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"estimated_time": { "estimated_time": {
"name": "Estimated time" "name": "Estimated time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"actual_time": { "actual_time": {
"name": "Actual time" "name": "Actual time",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"other_info": { "other_info": {
"name": "Other information" "name": "Other information",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
}, },
"deviation": { "deviation": {
"name": "Deviation" "name": "Deviation",
"state_attributes": {
"product_filter": {
"name": "[%key:component::trafikverket_train::entity::sensor::departure_time::state_attributes::product_filter::name%]"
}
}
} }
} }
} }

View File

@ -66,6 +66,7 @@ async def test_form(hass: HomeAssistant) -> None:
"time": "10:00", "time": "10:00",
"weekday": ["mon", "fri"], "weekday": ["mon", "fri"],
} }
assert result["options"] == {"filter_product": None}
assert len(mock_setup_entry.mock_calls) == 1 assert len(mock_setup_entry.mock_calls) == 1
assert result["result"].unique_id == "{}-{}-{}-{}".format( assert result["result"].unique_id == "{}-{}-{}-{}".format(
"stockholmc", "uppsalac", "10:00", "['mon', 'fri']" "stockholmc", "uppsalac", "10:00", "['mon', 'fri']"
@ -448,3 +449,55 @@ async def test_reauth_flow_error_departures(
"time": "10:00", "time": "10:00",
"weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"], "weekday": ["mon", "tue", "wed", "thu", "fri", "sat", "sun"],
} }
async def test_options_flow(hass: HomeAssistant) -> None:
"""Test a reauthentication flow."""
entry = MockConfigEntry(
domain=DOMAIN,
data={
CONF_API_KEY: "1234567890",
CONF_NAME: "Stockholm C to Uppsala C at 10:00",
CONF_FROM: "Stockholm C",
CONF_TO: "Uppsala C",
CONF_TIME: "10:00",
CONF_WEEKDAY: WEEKDAYS,
},
unique_id=f"stockholmc-uppsalac-10:00-{WEEKDAYS}",
)
entry.add_to_hass(hass)
with patch(
"homeassistant.components.trafikverket_train.async_setup_entry",
return_value=True,
):
assert await hass.config_entries.async_setup(entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"filter_product": "SJ Regionaltåg"},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {"filter_product": "SJ Regionaltåg"}
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] == FlowResultType.FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
user_input={"filter_product": ""},
)
await hass.async_block_till_done()
assert result["type"] == FlowResultType.CREATE_ENTRY
assert result["data"] == {"filter_product": None}