Add optional category in OptionsFlow to holiday (#129514)

This commit is contained in:
G Johansson 2024-12-17 20:20:26 +01:00 committed by GitHub
parent e9e8228f07
commit d785c4b0b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
6 changed files with 350 additions and 69 deletions

View File

@ -11,7 +11,7 @@ from homeassistant.const import CONF_COUNTRY, Platform
from homeassistant.core import HomeAssistant
from homeassistant.setup import SetupPhases, async_pause_setup
from .const import CONF_PROVINCE
from .const import CONF_CATEGORIES, CONF_PROVINCE
PLATFORMS: list[Platform] = [Platform.CALENDAR]
@ -20,6 +20,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Set up Holiday from a config entry."""
country: str = entry.data[CONF_COUNTRY]
province: str | None = entry.data.get(CONF_PROVINCE)
categories: list[str] | None = entry.options.get(CONF_CATEGORIES)
# We only import here to ensure that that its not imported later
# in the event loop since the platforms will call country_holidays
@ -29,14 +30,20 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
# the holidays library and it is not thread safe to import it in parallel
# https://github.com/python/cpython/issues/83065
await hass.async_add_import_executor_job(
partial(country_holidays, country, subdiv=province)
partial(country_holidays, country, subdiv=province, categories=categories)
)
await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS)
entry.async_on_unload(entry.add_update_listener(update_listener))
return True
async def update_listener(hass: HomeAssistant, entry: ConfigEntry) -> None:
"""Handle options update."""
await hass.config_entries.async_reload(entry.entry_id)
async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
"""Unload a config entry."""
return await hass.config_entries.async_unload_platforms(entry, PLATFORMS)

View File

@ -4,7 +4,7 @@ from __future__ import annotations
from datetime import datetime, timedelta
from holidays import HolidayBase, country_holidays
from holidays import PUBLIC, HolidayBase, country_holidays
from homeassistant.components.calendar import CalendarEntity, CalendarEvent
from homeassistant.config_entries import ConfigEntry
@ -15,18 +15,27 @@ from homeassistant.helpers.entity_platform import AddEntitiesCallback
from homeassistant.helpers.event import async_track_point_in_utc_time
from homeassistant.util import dt as dt_util
from .const import CONF_PROVINCE, DOMAIN
from .const import CONF_CATEGORIES, CONF_PROVINCE, DOMAIN
def _get_obj_holidays_and_language(
country: str, province: str | None, language: str
country: str,
province: str | None,
language: str,
selected_categories: list[str] | None,
) -> tuple[HolidayBase, str]:
"""Get the object for the requested country and year."""
if selected_categories is None:
categories = [PUBLIC]
else:
categories = [PUBLIC, *selected_categories]
obj_holidays = country_holidays(
country,
subdiv=province,
years={dt_util.now().year, dt_util.now().year + 1},
language=language,
categories=categories,
)
if language == "en":
for lang in obj_holidays.supported_languages:
@ -36,6 +45,7 @@ def _get_obj_holidays_and_language(
subdiv=province,
years={dt_util.now().year, dt_util.now().year + 1},
language=lang,
categories=categories,
)
language = lang
break
@ -49,6 +59,7 @@ def _get_obj_holidays_and_language(
subdiv=province,
years={dt_util.now().year, dt_util.now().year + 1},
language=default_language,
categories=categories,
)
language = default_language
@ -63,10 +74,11 @@ async def async_setup_entry(
"""Set up the Holiday Calendar config entry."""
country: str = config_entry.data[CONF_COUNTRY]
province: str | None = config_entry.data.get(CONF_PROVINCE)
categories: list[str] | None = config_entry.options.get(CONF_CATEGORIES)
language = hass.config.language
obj_holidays, language = await hass.async_add_executor_job(
_get_obj_holidays_and_language, country, province, language
_get_obj_holidays_and_language, country, province, language, categories
)
async_add_entities(
@ -76,6 +88,7 @@ async def async_setup_entry(
country,
province,
language,
categories,
obj_holidays,
config_entry.entry_id,
)
@ -99,6 +112,7 @@ class HolidayCalendarEntity(CalendarEntity):
country: str,
province: str | None,
language: str,
categories: list[str] | None,
obj_holidays: HolidayBase,
unique_id: str,
) -> None:
@ -107,6 +121,7 @@ class HolidayCalendarEntity(CalendarEntity):
self._province = province
self._location = name
self._language = language
self._categories = categories
self._attr_unique_id = unique_id
self._attr_device_info = DeviceInfo(
identifiers={(DOMAIN, unique_id)},
@ -172,6 +187,7 @@ class HolidayCalendarEntity(CalendarEntity):
subdiv=self._province,
years=list({start_date.year, end_date.year}),
language=self._language,
categories=self._categories,
)
event_list: list[CalendarEvent] = []

View File

@ -5,11 +5,17 @@ from __future__ import annotations
from typing import Any
from babel import Locale, UnknownLocaleError
from holidays import list_supported_countries
from holidays import PUBLIC, country_holidays, list_supported_countries
import voluptuous as vol
from homeassistant.config_entries import ConfigFlow, ConfigFlowResult
from homeassistant.config_entries import (
ConfigEntry,
ConfigFlow,
ConfigFlowResult,
OptionsFlow,
)
from homeassistant.const import CONF_COUNTRY
from homeassistant.core import callback
from homeassistant.helpers.selector import (
CountrySelector,
CountrySelectorConfig,
@ -17,12 +23,47 @@ from homeassistant.helpers.selector import (
SelectSelectorConfig,
SelectSelectorMode,
)
from homeassistant.util import dt as dt_util
from .const import CONF_PROVINCE, DOMAIN
from .const import CONF_CATEGORIES, CONF_PROVINCE, DOMAIN
SUPPORTED_COUNTRIES = list_supported_countries(include_aliases=False)
def get_optional_categories(country: str) -> list[str]:
"""Return the country categories.
public holidays are always included so they
don't need to be presented to the user.
"""
country_data = country_holidays(country, years=dt_util.utcnow().year)
return [
category for category in country_data.supported_categories if category != PUBLIC
]
def get_options_schema(country: str) -> vol.Schema:
"""Return the options schema."""
schema = {}
if provinces := SUPPORTED_COUNTRIES[country]:
schema[vol.Optional(CONF_PROVINCE)] = SelectSelector(
SelectSelectorConfig(
options=provinces,
mode=SelectSelectorMode.DROPDOWN,
)
)
if categories := get_optional_categories(country):
schema[vol.Optional(CONF_CATEGORIES)] = SelectSelector(
SelectSelectorConfig(
options=categories,
multiple=True,
mode=SelectSelectorMode.DROPDOWN,
translation_key="categories",
)
)
return vol.Schema(schema)
class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
"""Handle a config flow for Holiday."""
@ -32,6 +73,12 @@ class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
"""Initialize the config flow."""
self.data: dict[str, Any] = {}
@staticmethod
@callback
def async_get_options_flow(config_entry: ConfigEntry) -> HolidayOptionsFlowHandler:
"""Get the options flow for this handler."""
return HolidayOptionsFlowHandler()
async def async_step_user(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
@ -41,8 +88,11 @@ class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
selected_country = user_input[CONF_COUNTRY]
if SUPPORTED_COUNTRIES[selected_country]:
return await self.async_step_province()
options_schema = await self.hass.async_add_executor_job(
get_options_schema, selected_country
)
if options_schema.schema:
return await self.async_step_options()
self._async_abort_entries_match({CONF_COUNTRY: user_input[CONF_COUNTRY]})
@ -67,24 +117,22 @@ class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
}
)
return self.async_show_form(step_id="user", data_schema=user_schema)
return self.async_show_form(data_schema=user_schema)
async def async_step_province(
async def async_step_options(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the province step."""
"""Handle the options step."""
if user_input is not None:
combined_input: dict[str, Any] = {**self.data, **user_input}
country = self.data[CONF_COUNTRY]
data = {CONF_COUNTRY: country}
options: dict[str, Any] | None = None
if province := user_input.get(CONF_PROVINCE):
data[CONF_PROVINCE] = province
if categories := user_input.get(CONF_CATEGORIES):
options = {CONF_CATEGORIES: categories}
country = combined_input[CONF_COUNTRY]
province = combined_input.get(CONF_PROVINCE)
self._async_abort_entries_match(
{
CONF_COUNTRY: country,
CONF_PROVINCE: province,
}
)
self._async_abort_entries_match({**data, **(options or {})})
try:
locale = Locale.parse(self.hass.config.language, sep="-")
@ -95,38 +143,33 @@ class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
province_str = f", {province}" if province else ""
name = f"{locale.territories[country]}{province_str}"
return self.async_create_entry(title=name, data=combined_input)
return self.async_create_entry(title=name, data=data, options=options)
province_schema = vol.Schema(
{
vol.Optional(CONF_PROVINCE): SelectSelector(
SelectSelectorConfig(
options=SUPPORTED_COUNTRIES[self.data[CONF_COUNTRY]],
mode=SelectSelectorMode.DROPDOWN,
)
),
}
options_schema = await self.hass.async_add_executor_job(
get_options_schema, self.data[CONF_COUNTRY]
)
return self.async_show_form(
step_id="options",
data_schema=options_schema,
description_placeholders={CONF_COUNTRY: self.data[CONF_COUNTRY]},
)
return self.async_show_form(step_id="province", data_schema=province_schema)
async def async_step_reconfigure(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Handle the re-configuration of a province."""
"""Handle the re-configuration of the options."""
reconfigure_entry = self._get_reconfigure_entry()
if user_input is not None:
combined_input: dict[str, Any] = {**reconfigure_entry.data, **user_input}
country = reconfigure_entry.data[CONF_COUNTRY]
data = {CONF_COUNTRY: country}
options: dict[str, Any] | None = None
if province := user_input.get(CONF_PROVINCE):
data[CONF_PROVINCE] = province
if categories := user_input.get(CONF_CATEGORIES):
options = {CONF_CATEGORIES: categories}
country = combined_input[CONF_COUNTRY]
province = combined_input.get(CONF_PROVINCE)
self._async_abort_entries_match(
{
CONF_COUNTRY: country,
CONF_PROVINCE: province,
}
)
self._async_abort_entries_match({**data, **(options or {})})
try:
locale = Locale.parse(self.hass.config.language, sep="-")
@ -137,21 +180,60 @@ class HolidayConfigFlow(ConfigFlow, domain=DOMAIN):
province_str = f", {province}" if province else ""
name = f"{locale.territories[country]}{province_str}"
if options:
return self.async_update_reload_and_abort(
reconfigure_entry, title=name, data=data, options=options
)
return self.async_update_reload_and_abort(
reconfigure_entry, title=name, data=combined_input
reconfigure_entry, title=name, data=data
)
province_schema = vol.Schema(
options_schema = await self.hass.async_add_executor_job(
get_options_schema, reconfigure_entry.data[CONF_COUNTRY]
)
return self.async_show_form(
data_schema=options_schema,
description_placeholders={
CONF_COUNTRY: reconfigure_entry.data[CONF_COUNTRY]
},
)
class HolidayOptionsFlowHandler(OptionsFlow):
"""Handle Holiday options."""
async def async_step_init(
self, user_input: dict[str, Any] | None = None
) -> ConfigFlowResult:
"""Manage Holiday options."""
if user_input is not None:
return self.async_create_entry(data=user_input)
categories = await self.hass.async_add_executor_job(
get_optional_categories, self.config_entry.data[CONF_COUNTRY]
)
if not categories:
return self.async_abort(reason="no_categories")
schema = vol.Schema(
{
vol.Optional(CONF_PROVINCE): SelectSelector(
vol.Optional(CONF_CATEGORIES): SelectSelector(
SelectSelectorConfig(
options=SUPPORTED_COUNTRIES[
reconfigure_entry.data[CONF_COUNTRY]
],
options=categories,
multiple=True,
mode=SelectSelectorMode.DROPDOWN,
translation_key="categories",
)
)
}
)
return self.async_show_form(step_id="reconfigure", data_schema=province_schema)
return self.async_show_form(
data_schema=self.add_suggested_values_to_schema(
schema, self.config_entry.options
),
description_placeholders={
CONF_COUNTRY: self.config_entry.data[CONF_COUNTRY]
},
)

View File

@ -5,3 +5,4 @@ from typing import Final
DOMAIN: Final = "holiday"
CONF_PROVINCE: Final = "province"
CONF_CATEGORIES: Final = "categories"

View File

@ -2,7 +2,7 @@
"title": "Holiday",
"config": {
"abort": {
"already_configured": "Already configured. Only a single configuration for country/province combination possible.",
"already_configured": "Already configured. Only a single configuration for country/province/categories combination is possible.",
"reconfigure_successful": "[%key:common::config_flow::abort::reconfigure_successful%]"
},
"step": {
@ -11,16 +11,62 @@
"country": "Country"
}
},
"province": {
"options": {
"data": {
"province": "Province"
"province": "Province",
"categories": "Categories"
},
"data_description": {
"province": "Optionally choose a province / subdivision of {country}",
"categories": "Optionally choose additional holiday categories, public holidays are already included"
}
},
"reconfigure": {
"data": {
"province": "[%key:component::holiday::config::step::province::data::province%]"
"province": "[%key:component::holiday::config::step::options::data::province%]",
"categories": "[%key:component::holiday::config::step::options::data::categories%]"
},
"data_description": {
"province": "[%key:component::holiday::config::step::options::data_description::province%]",
"categories": "[%key:component::holiday::config::step::options::data_description::categories%]"
}
}
}
},
"options": {
"abort": {
"already_configured": "[%key:component::holiday::config::abort::already_configured%]",
"no_categories": "The country has no additional categories to configure."
},
"step": {
"init": {
"data": {
"categories": "[%key:component::holiday::config::step::options::data::categories%]"
},
"data_description": {
"categories": "[%key:component::holiday::config::step::options::data_description::categories%]"
}
}
}
},
"selector": {
"device_class": {
"options": {
"armed_forces": "Armed forces",
"bank": "Bank",
"catholic": "Catholic",
"chinese": "Chinese",
"christian": "Christian",
"government": "Government",
"half_day": "Half day",
"hebrew": "Hebrew",
"hindu": "Hindu",
"islamic": "Islamic",
"optional": "Optional",
"school": "School",
"unofficial": "Unofficial",
"workday": "Workday"
}
}
}
}

View File

@ -1,19 +1,25 @@
"""Test the Holiday config flow."""
from datetime import datetime
from unittest.mock import AsyncMock
from freezegun.api import FrozenDateTimeFactory
from holidays import UNOFFICIAL
import pytest
from homeassistant import config_entries
from homeassistant.components.holiday.const import CONF_PROVINCE, DOMAIN
from homeassistant.const import CONF_COUNTRY
from homeassistant.components.holiday.const import (
CONF_CATEGORIES,
CONF_PROVINCE,
DOMAIN,
)
from homeassistant.const import CONF_COUNTRY, STATE_OFF, STATE_ON
from homeassistant.core import HomeAssistant
from homeassistant.data_entry_flow import FlowResultType
from homeassistant.util import dt as dt_util
from tests.common import MockConfigEntry
pytestmark = pytest.mark.usefixtures("mock_setup_entry")
async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
"""Test we get the form."""
@ -49,6 +55,7 @@ async def test_form(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
assert len(mock_setup_entry.mock_calls) == 1
@pytest.mark.usefixtures("mock_setup_entry")
async def test_form_no_subdivision(hass: HomeAssistant) -> None:
"""Test we get the forms correctly without subdivision."""
result = await hass.config_entries.flow.async_init(
@ -71,6 +78,7 @@ async def test_form_no_subdivision(hass: HomeAssistant) -> None:
}
@pytest.mark.usefixtures("mock_setup_entry")
async def test_form_translated_title(hass: HomeAssistant) -> None:
"""Test the title gets translated."""
hass.config.language = "de"
@ -90,6 +98,7 @@ async def test_form_translated_title(hass: HomeAssistant) -> None:
assert result2["title"] == "Schweden"
@pytest.mark.usefixtures("mock_setup_entry")
async def test_single_combination_country_province(hass: HomeAssistant) -> None:
"""Test that configuring more than one instance is rejected."""
data_de = {
@ -129,6 +138,7 @@ async def test_single_combination_country_province(hass: HomeAssistant) -> None:
assert result_de_step2["reason"] == "already_configured"
@pytest.mark.usefixtures("mock_setup_entry")
async def test_form_babel_unresolved_language(hass: HomeAssistant) -> None:
"""Test the config flow if using not babel supported language."""
hass.config.language = "en-XX"
@ -175,6 +185,7 @@ async def test_form_babel_unresolved_language(hass: HomeAssistant) -> None:
}
@pytest.mark.usefixtures("mock_setup_entry")
async def test_form_babel_replace_dash_with_underscore(hass: HomeAssistant) -> None:
"""Test the config flow if using language with dash."""
hass.config.language = "en-GB"
@ -221,7 +232,8 @@ async def test_form_babel_replace_dash_with_underscore(hass: HomeAssistant) -> N
}
async def test_reconfigure(hass: HomeAssistant, mock_setup_entry: AsyncMock) -> None:
@pytest.mark.usefixtures("mock_setup_entry")
async def test_reconfigure(hass: HomeAssistant) -> None:
"""Test reconfigure flow."""
entry = MockConfigEntry(
domain=DOMAIN,
@ -248,9 +260,38 @@ async def test_reconfigure(hass: HomeAssistant, mock_setup_entry: AsyncMock) ->
assert entry.data == {"country": "DE", "province": "NW"}
async def test_reconfigure_incorrect_language(
hass: HomeAssistant, mock_setup_entry: AsyncMock
) -> None:
@pytest.mark.usefixtures("mock_setup_entry")
async def test_reconfigure_with_categories(hass: HomeAssistant) -> None:
"""Test reconfigure flow with categories."""
entry = MockConfigEntry(
domain=DOMAIN,
title="Unites States, TX",
data={"country": "US", "province": "TX"},
)
entry.add_to_hass(hass)
result = await entry.start_reconfigure_flow(hass)
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_PROVINCE: "AL",
CONF_CATEGORIES: [UNOFFICIAL],
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "reconfigure_successful"
entry = hass.config_entries.async_get_entry(entry.entry_id)
assert entry.title == "United States, AL"
assert entry.data == {CONF_COUNTRY: "US", CONF_PROVINCE: "AL"}
assert entry.options == {CONF_CATEGORIES: ["unofficial"]}
@pytest.mark.usefixtures("mock_setup_entry")
async def test_reconfigure_incorrect_language(hass: HomeAssistant) -> None:
"""Test reconfigure flow default to English."""
hass.config.language = "en-XX"
@ -279,9 +320,8 @@ async def test_reconfigure_incorrect_language(
assert entry.data == {"country": "DE", "province": "NW"}
async def test_reconfigure_entry_exists(
hass: HomeAssistant, mock_setup_entry: AsyncMock
) -> None:
@pytest.mark.usefixtures("mock_setup_entry")
async def test_reconfigure_entry_exists(hass: HomeAssistant) -> None:
"""Test reconfigure flow stops if other entry already exist."""
entry = MockConfigEntry(
domain=DOMAIN,
@ -312,3 +352,92 @@ async def test_reconfigure_entry_exists(
entry = hass.config_entries.async_get_entry(entry.entry_id)
assert entry.title == "Germany, BW"
assert entry.data == {"country": "DE", "province": "BW"}
async def test_form_with_options(
hass: HomeAssistant,
freezer: FrozenDateTimeFactory,
) -> None:
"""Test the flow with configuring options."""
await hass.config.async_set_time_zone("America/Chicago")
zone = await dt_util.async_get_time_zone("America/Chicago")
# Oct 31st is a Friday. Unofficial holiday as Halloween
freezer.move_to(datetime(2024, 10, 31, 12, 0, 0, tzinfo=zone))
result = await hass.config_entries.flow.async_init(
DOMAIN, context={"source": config_entries.SOURCE_USER}
)
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_COUNTRY: "US",
},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.FORM
result = await hass.config_entries.flow.async_configure(
result["flow_id"],
{
CONF_PROVINCE: "TX",
CONF_CATEGORIES: [UNOFFICIAL],
},
)
await hass.async_block_till_done(wait_background_tasks=True)
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["title"] == "United States, TX"
assert result["data"] == {
CONF_COUNTRY: "US",
CONF_PROVINCE: "TX",
}
assert result["options"] == {
CONF_CATEGORIES: ["unofficial"],
}
state = hass.states.get("calendar.united_states_tx")
assert state
assert state.state == STATE_ON
entries = hass.config_entries.async_entries(DOMAIN)
entry = entries[0]
result = await hass.config_entries.options.async_init(entry.entry_id)
assert result["type"] is FlowResultType.FORM
assert result["step_id"] == "init"
result = await hass.config_entries.options.async_configure(
result["flow_id"],
{CONF_CATEGORIES: []},
)
await hass.async_block_till_done()
assert result["type"] is FlowResultType.CREATE_ENTRY
assert result["data"] == {
CONF_CATEGORIES: [],
}
state = hass.states.get("calendar.united_states_tx")
assert state
assert state.state == STATE_OFF
@pytest.mark.usefixtures("mock_setup_entry")
async def test_options_abort_no_categories(hass: HomeAssistant) -> None:
"""Test the options flow abort if no categories to select."""
config_entry = MockConfigEntry(
domain=DOMAIN,
data={CONF_COUNTRY: "SE"},
title="Sweden",
)
config_entry.add_to_hass(hass)
await hass.config_entries.async_setup(config_entry.entry_id)
await hass.async_block_till_done()
result = await hass.config_entries.options.async_init(config_entry.entry_id)
assert result["type"] is FlowResultType.ABORT
assert result["reason"] == "no_categories"