Prevent zero interval in Calendar get_events service (#139378)

* Prevent zero interval in Calendar get_events service

* Fix holiday calendar tests

* Remove redundant entity_id

* Use translation for exception

* Replace check with voluptuous validator

* Revert strings.xml
This commit is contained in:
Abílio Costa 2025-03-04 08:52:29 +00:00 committed by Franck Nijhof
parent 50aefc3653
commit 0940fc7806
No known key found for this signature in database
GPG Key ID: D62583BA8AB11CA3
3 changed files with 81 additions and 7 deletions

View File

@ -153,6 +153,27 @@ def _has_min_duration(
return validate return validate
def _has_positive_interval(
start_key: str, end_key: str, duration_key: str
) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that the time span between start and end is greater than zero."""
def validate(obj: dict[str, Any]) -> dict[str, Any]:
if (duration := obj.get(duration_key)) is not None:
if duration <= datetime.timedelta(seconds=0):
raise vol.Invalid(f"Expected positive duration ({duration})")
return obj
if (start := obj.get(start_key)) and (end := obj.get(end_key)):
if start >= end:
raise vol.Invalid(
f"Expected end time to be after start time ({start}, {end})"
)
return obj
return validate
def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]:
"""Verify that all values are of the same type.""" """Verify that all values are of the same type."""
@ -281,6 +302,7 @@ SERVICE_GET_EVENTS_SCHEMA: Final = vol.All(
), ),
} }
), ),
_has_positive_interval(EVENT_START_DATETIME, EVENT_END_DATETIME, EVENT_DURATION),
) )
@ -870,6 +892,7 @@ async def async_get_events_service(
end = start + service_call.data[EVENT_DURATION] end = start + service_call.data[EVENT_DURATION]
else: else:
end = service_call.data[EVENT_END_DATETIME] end = service_call.data[EVENT_END_DATETIME]
calendar_event_list = await calendar.async_get_events( calendar_event_list = await calendar.async_get_events(
calendar.hass, dt_util.as_local(start), dt_util.as_local(end) calendar.hass, dt_util.as_local(start), dt_util.as_local(end)
) )

View File

@ -5,6 +5,7 @@ from __future__ import annotations
from collections.abc import Generator from collections.abc import Generator
from datetime import timedelta from datetime import timedelta
from http import HTTPStatus from http import HTTPStatus
import re
from typing import Any from typing import Any
from freezegun import freeze_time from freezegun import freeze_time
@ -448,7 +449,7 @@ async def test_list_events_service(
service: str, service: str,
expected: dict[str, Any], expected: dict[str, Any],
) -> None: ) -> None:
"""Test listing events from the service call using exlplicit start and end time. """Test listing events from the service call using explicit start and end time.
This test uses a fixed date/time so that it can deterministically test the This test uses a fixed date/time so that it can deterministically test the
string output values. string output values.
@ -553,3 +554,53 @@ async def test_list_events_missing_fields(hass: HomeAssistant) -> None:
blocking=True, blocking=True,
return_response=True, return_response=True,
) )
@pytest.mark.parametrize(
"frozen_time", ["2023-06-22 10:30:00+00:00"], ids=["frozen_time"]
)
@pytest.mark.parametrize(
("service_data", "error_msg"),
[
(
{
"start_date_time": "2023-06-22T04:30:00-06:00",
"end_date_time": "2023-06-22T04:30:00-06:00",
},
"Expected end time to be after start time (2023-06-22 04:30:00-06:00, 2023-06-22 04:30:00-06:00)",
),
(
{
"start_date_time": "2023-06-22T04:30:00",
"end_date_time": "2023-06-22T04:30:00",
},
"Expected end time to be after start time (2023-06-22 04:30:00, 2023-06-22 04:30:00)",
),
(
{"start_date_time": "2023-06-22", "end_date_time": "2023-06-22"},
"Expected end time to be after start time (2023-06-22 00:00:00, 2023-06-22 00:00:00)",
),
(
{"start_date_time": "2023-06-22 10:00:00", "duration": "0"},
"Expected positive duration (0:00:00)",
),
],
)
async def test_list_events_service_same_dates(
hass: HomeAssistant,
service_data: dict[str, str],
error_msg: str,
) -> None:
"""Test listing events from the service call using the same start and end time."""
with pytest.raises(vol.error.MultipleInvalid, match=re.escape(error_msg)):
await hass.services.async_call(
DOMAIN,
SERVICE_GET_EVENTS,
service_data={
"entity_id": "calendar.calendar_1",
**service_data,
},
blocking=True,
return_response=True,
)

View File

@ -49,7 +49,7 @@ async def test_holiday_calendar_entity(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.united_states_ak", "entity_id": "calendar.united_states_ak",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -135,7 +135,7 @@ async def test_default_language(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.france_bl", "entity_id": "calendar.france_bl",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -164,7 +164,7 @@ async def test_default_language(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.france_bl", "entity_id": "calendar.france_bl",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -211,7 +211,7 @@ async def test_no_language(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.albania", "entity_id": "calendar.albania",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -308,7 +308,7 @@ async def test_language_not_exist(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.norge", "entity_id": "calendar.norge",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,
@ -336,7 +336,7 @@ async def test_language_not_exist(
SERVICE_GET_EVENTS, SERVICE_GET_EVENTS,
{ {
"entity_id": "calendar.norge", "entity_id": "calendar.norge",
"end_date_time": dt_util.now(), "end_date_time": dt_util.now() + timedelta(hours=1),
}, },
blocking=True, blocking=True,
return_response=True, return_response=True,