diff --git a/homeassistant/components/calendar/__init__.py b/homeassistant/components/calendar/__init__.py index 40d6952fa64..96bf717c3ac 100644 --- a/homeassistant/components/calendar/__init__.py +++ b/homeassistant/components/calendar/__init__.py @@ -153,6 +153,27 @@ def _has_min_duration( return validate +def _has_positive_interval( + start_key: str, end_key: str, duration_key: str +) -> Callable[[dict[str, Any]], dict[str, Any]]: + """Verify that the time span between start and end is greater than zero.""" + + def validate(obj: dict[str, Any]) -> dict[str, Any]: + if (duration := obj.get(duration_key)) is not None: + if duration <= datetime.timedelta(seconds=0): + raise vol.Invalid(f"Expected positive duration ({duration})") + return obj + + if (start := obj.get(start_key)) and (end := obj.get(end_key)): + if start >= end: + raise vol.Invalid( + f"Expected end time to be after start time ({start}, {end})" + ) + return obj + + return validate + + def _has_same_type(*keys: Any) -> Callable[[dict[str, Any]], dict[str, Any]]: """Verify that all values are of the same type.""" @@ -281,6 +302,7 @@ SERVICE_GET_EVENTS_SCHEMA: Final = vol.All( ), } ), + _has_positive_interval(EVENT_START_DATETIME, EVENT_END_DATETIME, EVENT_DURATION), ) @@ -870,6 +892,7 @@ async def async_get_events_service( end = start + service_call.data[EVENT_DURATION] else: end = service_call.data[EVENT_END_DATETIME] + calendar_event_list = await calendar.async_get_events( calendar.hass, dt_util.as_local(start), dt_util.as_local(end) ) diff --git a/tests/components/calendar/test_init.py b/tests/components/calendar/test_init.py index 2d712f408c2..6de0a7ef936 100644 --- a/tests/components/calendar/test_init.py +++ b/tests/components/calendar/test_init.py @@ -5,6 +5,7 @@ from __future__ import annotations from collections.abc import Generator from datetime import timedelta from http import HTTPStatus +import re from typing import Any from freezegun import freeze_time @@ -448,7 +449,7 @@ async def test_list_events_service( service: str, expected: dict[str, Any], ) -> None: - """Test listing events from the service call using exlplicit start and end time. + """Test listing events from the service call using explicit start and end time. This test uses a fixed date/time so that it can deterministically test the string output values. @@ -553,3 +554,53 @@ async def test_list_events_missing_fields(hass: HomeAssistant) -> None: blocking=True, return_response=True, ) + + +@pytest.mark.parametrize( + "frozen_time", ["2023-06-22 10:30:00+00:00"], ids=["frozen_time"] +) +@pytest.mark.parametrize( + ("service_data", "error_msg"), + [ + ( + { + "start_date_time": "2023-06-22T04:30:00-06:00", + "end_date_time": "2023-06-22T04:30:00-06:00", + }, + "Expected end time to be after start time (2023-06-22 04:30:00-06:00, 2023-06-22 04:30:00-06:00)", + ), + ( + { + "start_date_time": "2023-06-22T04:30:00", + "end_date_time": "2023-06-22T04:30:00", + }, + "Expected end time to be after start time (2023-06-22 04:30:00, 2023-06-22 04:30:00)", + ), + ( + {"start_date_time": "2023-06-22", "end_date_time": "2023-06-22"}, + "Expected end time to be after start time (2023-06-22 00:00:00, 2023-06-22 00:00:00)", + ), + ( + {"start_date_time": "2023-06-22 10:00:00", "duration": "0"}, + "Expected positive duration (0:00:00)", + ), + ], +) +async def test_list_events_service_same_dates( + hass: HomeAssistant, + service_data: dict[str, str], + error_msg: str, +) -> None: + """Test listing events from the service call using the same start and end time.""" + + with pytest.raises(vol.error.MultipleInvalid, match=re.escape(error_msg)): + await hass.services.async_call( + DOMAIN, + SERVICE_GET_EVENTS, + service_data={ + "entity_id": "calendar.calendar_1", + **service_data, + }, + blocking=True, + return_response=True, + ) diff --git a/tests/components/holiday/test_calendar.py b/tests/components/holiday/test_calendar.py index db58b7b1f73..6733d38442b 100644 --- a/tests/components/holiday/test_calendar.py +++ b/tests/components/holiday/test_calendar.py @@ -49,7 +49,7 @@ async def test_holiday_calendar_entity( SERVICE_GET_EVENTS, { "entity_id": "calendar.united_states_ak", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True, @@ -135,7 +135,7 @@ async def test_default_language( SERVICE_GET_EVENTS, { "entity_id": "calendar.france_bl", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True, @@ -164,7 +164,7 @@ async def test_default_language( SERVICE_GET_EVENTS, { "entity_id": "calendar.france_bl", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True, @@ -211,7 +211,7 @@ async def test_no_language( SERVICE_GET_EVENTS, { "entity_id": "calendar.albania", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True, @@ -308,7 +308,7 @@ async def test_language_not_exist( SERVICE_GET_EVENTS, { "entity_id": "calendar.norge", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True, @@ -336,7 +336,7 @@ async def test_language_not_exist( SERVICE_GET_EVENTS, { "entity_id": "calendar.norge", - "end_date_time": dt_util.now(), + "end_date_time": dt_util.now() + timedelta(hours=1), }, blocking=True, return_response=True,